Commit fcfb4095 authored by Benjamin Murauer's avatar Benjamin Murauer
Browse files

added load_dataframe comfort function

parent 0b3cf790
Pipeline #39000 passed with stage
in 3 minutes and 2 seconds
......@@ -4,6 +4,7 @@ import pandas as pd
from .db import DB
from .db import DbModel
from .utils import LOGGER
from .utils import load_project_config
......@@ -248,3 +249,78 @@ def extract_gridsearch_parameters(
use_prefix=prefix_parameter_names))
result_rows.append(result_row)
return pd.DataFrame(result_rows)
def load_dataframe(allow_multiple_git_ids=False, allow_dirty_rows=False,
**query_filters):
"""
A flexible wrapper for queries.
Args:
allow_multiple_git_ids (bool): If false, raises an error when loading
results with different git commit ids
allow_dirty_rows (bool): If false, raises an error when loading rows
resulting from a dirty git state
**query_filters: key-value restrictions on the DB query. The values
can have three different representations:
- single vaules, then they will be matched with ==
- comma separated values, then they will be matched with 'in (...)'
- values separated with ' to ', then they will be matched with >=
and <= if one of the two borders is an asterisk, only the other
border will be used. Couldn't use the dash, as it is needed in
date-based queries.
Returns:
A dataframe resulting from the query, and an object containing
additional information of the git and row ids used in this result, and
a list of all scores that were found in the result.
Examples:
- load a single git commit:
df, info = load_dataframe(git_commit_id='2aace91d317694a08...')
- load all results between two row ids
df, info = load_dataframe(id="1200 to 1300")
- specifiy multiple query filters
df, info = load_dataframe(
git_commit_id="...",
sourcefile="plan1.py,plan2.py",
)
"""
session = DB.session()
query = session.query(DbModel)
for field, value in query_filters.items():
if not value:
continue
if ' to ' in value:
min_val, max_val = value.split(' to ')
if min_val != '*':
query = query.filter(getattr(DbModel, field) >= min_val)
if max_val != '*':
query = query.filter(getattr(DbModel, field) <= max_val)
elif ',' in value:
query = query.filter(getattr(DbModel, field).in_(value.split(',')))
else:
query = query.filter(getattr(DbModel, field) == value)
df = rows_to_dataframe(query, allow_multiple_git_ids, allow_dirty_rows)
LOGGER.debug('loaded %s raw rows from db', df.shape[0])
extra_info = {}
scores = None
for _, row in df.iterrows():
row_scores = set([k for k in row['outcome']['cv_results'].keys()
if k.startswith('mean_test_')])
if scores is None:
scores = row_scores
else:
if scores - row_scores or row_scores - scores:
raise ValueError(
'you have loaded rows that have different scoring fields, '
f'which is not supported: {scores} vs. {row_scores}')
LOGGER.debug('extracted scores: %s', scores)
extra_info['scores'] = scores
extra_info['git_ids'] = set(df['git_commit_id'].values)
extra_info['row_ids'] = set(df['id'].values)
df = extract_gridsearch_parameters(df, scores)
return df, extra_info
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment