Commit 6fca1257 authored by Benjamin Murauer's avatar Benjamin Murauer
Browse files

allow load_dataframe to only extract cv results if these are present in outcome column

parent ecfa29c1
Pipeline #41298 passed with stage
in 2 minutes and 3 seconds
......@@ -307,7 +307,12 @@ def load_dataframe(allow_multiple_git_ids=False, allow_dirty_rows=False,
LOGGER.debug('loaded %s raw rows from db', df.shape[0])
extra_info = {}
scores = None
has_cv_results = False
for _, row in df.iterrows():
if 'cv_results' in row:
has_cv_results = True
row_scores = set([k for k in row['outcome']['cv_results'].keys()
if k.startswith('mean_test_')])
if scores is None:
......@@ -317,10 +322,11 @@ def load_dataframe(allow_multiple_git_ids=False, allow_dirty_rows=False,
raise ValueError(
'you have loaded rows that have different scoring fields, '
f'which is not supported: {scores} vs. {row_scores}')
LOGGER.debug('extracted scores: %s', scores)
LOGGER.debug('extracted scores: %s', scores)
extra_info['scores'] = scores
extra_info['git_ids'] = set(df['git_commit_id'].values)
extra_info['row_ids'] = set(df['id'].values)
df = extract_gridsearch_parameters(df, scores)
if has_cv_results:
df = extract_gridsearch_parameters(df, scores)
return df, extra_info
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment