Commit 324af920 authored by Benjamin Murauer's avatar Benjamin Murauer

Merge branch '55-improve-execution-time-of-get_results_as_dataframe' into 'master'

Resolve "Improve execution time of get_results_as_dataframe"

Closes #55

See merge request dbis/software/dbispipeline!36
parents 9d607791 a67c2fc1
Pipeline #42163 passed in 2 minutes and 29 seconds
@@ -93,9 +93,9 @@ install: clean ## install the package to the active Python's site-packages
format: ## formats the code
	yapf -i -r src
	yapf -i -r tests
	isort src
	isort tests

bandit: ## static code checking to find security issues in code
	bandit -r src
	bandit -r tests
[build-system]
requires = ["setuptools", "wheel"]
build-backend = "setuptools.build_meta:__legacy__"
"""Module containing tools to run result analytics."""
from typing import List
import numpy as np
import pandas as pd
from .db import DB
from .db import DbModel
from .utils import LOGGER
from .utils import SECTION_DATABASE
from .utils import SECTION_PROJECT
from .utils import load_project_config
@@ -13,22 +17,75 @@ def _extract_metric_results(outcome, requested_metric):
        lambda row: row.apply(lambda value: value[requested_metric]))
-def get_results_as_dataframe(project_name,
-                             table_name='results',
-                             filter_git_dirty=True):
-    """Returns the results stored in the database as a pandas dataframe.
+def get_results_as_dataframe(project_name: str = None,
+                             table_name: str = None,
+                             filter_git_dirty=True,
+                             date_filter: str = None,
+                             id_filter: str = None,
+                             additional_conditions: List[str] = None,
+                             columns: List[str] = None) -> pd.DataFrame:
+    r"""Returns the results stored in the database as a pandas dataframe.
    Args:
-        project_name: the project name to fetch results.
-        table_name: the name of the reults table.
-        filter_git_dirty: defines if dirty commits are filterd.
+        project_name (str, optional): the project name to fetch results. It
+            defaults to the project name given in the config.
+        table_name (str, optional): the name of the result table. It defaults
+            to the table name specified in the config.
+        filter_git_dirty (bool, optional): defines if dirty commits are
+            filtered.
+        date_filter (str, optional): filter by date as a string.
+            E.g. "> \'2021-01-01\'"
+        id_filter (str, optional): filter by id. E.g. "= 42" or a comma
+            separated list of ids.
+        additional_conditions (List[str], optional): a list of strings that
+            gets added to the WHERE clause using AND to combine it with other
+            filters.
+        columns (List[str], optional): a list of columns that should be
+            returned. None equals to all.
    Returns:
        pd.DataFrame: the result as a dataframe.
    """
-    results = pd.read_sql_table(table_name=table_name, con=DB.engine)
+    if project_name is None or table_name is None:
+        config = load_project_config()
+        if project_name is None:
+            project_name = config[SECTION_PROJECT]['name']
+        if table_name is None:
+            table_name = config[SECTION_DATABASE]['result_table']
+    if columns is None:
+        columns = '*'
+    else:
+        columns = ', '.join(columns)
+    sql = 'SELECT %s FROM %s' % (columns, table_name)
+    conditions = []
+    if project_name:
+        conditions.append('project_name LIKE \'%s\'' % project_name)
+    if id_filter:
+        if ',' in id_filter and 'in' not in id_filter.lower():
+            id_filter = 'IN(' + id_filter + ')'
+        conditions.append('id %s' % id_filter)
     if filter_git_dirty:
-        results = results[results['git_is_dirty'] == False]  # noqa: E712
+        conditions.append('git_is_dirty = FALSE')
+    if date_filter:
+        conditions.append('"date" %s' % date_filter)
+    if additional_conditions:
+        conditions = conditions + additional_conditions
+    if len(conditions) > 1:
+        where_conditions = ' AND '.join(conditions)
+    else:
+        where_conditions = conditions[0]
+    if len(conditions) > 0:
+        sql = sql + ' WHERE ' + where_conditions
-    return results[results['project_name'] == project_name]
+    return pd.read_sql_query(sql, con=DB.engine)
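
For illustration, a hedged usage sketch of the rewritten query path. The project name and filter values are invented, and the table name is assumed to be a configured default called results; the SQL in the comment follows the shape asserted by the tests below.

from dbispipeline.analytics import get_results_as_dataframe

# Expected to build and execute roughly:
#   SELECT id, "date", outcome FROM results
#   WHERE project_name LIKE 'my_project' AND id IN(32, 42, 52)
#   AND git_is_dirty = FALSE AND "date" > '2021-01-01'
df = get_results_as_dataframe(
    project_name='my_project',      # invented; defaults to the config value
    id_filter='32, 42, 52',         # a comma list is rewritten to IN(...)
    date_filter="> '2021-01-01'",
    columns=['id', '"date"', 'outcome'],
)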
def fetch_by_git_commit_id(git_commit_id):
@@ -257,7 +314,8 @@ def extract_gridsearch_parameters(
    return pd.DataFrame(result_rows)


-def load_dataframe(allow_multiple_git_ids=False, allow_dirty_rows=False,
+def load_dataframe(allow_multiple_git_ids=False,
+                   allow_dirty_rows=False,
                    **query_filters):
    """
    A flexible wrapper for queries.
@@ -319,8 +377,10 @@ def load_dataframe(allow_multiple_git_ids=False, allow_dirty_rows=False,
        if 'cv_results' in row:
            has_cv_results = True
-            row_scores = set([k for k in row['outcome']['cv_results'].keys()
-                              if k.startswith('mean_test_')])
+            row_scores = set([
+                k for k in row['outcome']['cv_results'].keys()
+                if k.startswith('mean_test_')
+            ])
            if scores is None:
                scores = row_scores
            else:
......
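
As background for the mean_test_ filter above (standard scikit-learn behavior, not something introduced by this change): with multi-metric scoring, GridSearchCV.cv_results_ contains one mean_test_<scorer> key per scoring function, which is exactly what the set comprehension collects. A minimal self-contained sketch:

from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV

X, y = load_iris(return_X_y=True)
search = GridSearchCV(
    LogisticRegression(max_iter=1000),
    param_grid={'C': [0.1, 1.0]},
    scoring=['accuracy', 'f1_macro'],  # multi-metric scoring
    refit='accuracy',                  # required when scoring is a list
)
search.fit(X, y)
# cv_results_ now holds 'mean_test_accuracy' and 'mean_test_f1_macro'.
scores = {k for k in search.cv_results_ if k.startswith('mean_test_')}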
"""Tool to manage data."""
import os
+from logzero import logger
import yaml
from dbispipeline import utils
-from logzero import logger
LINK_CONFIG_FILE = 'data/links.yaml'
......
@@ -10,7 +10,11 @@ from dbispipeline.utils import LOGGER
class DataframeLoader(Loader):
    """A generic one-in-all solution."""

-    def __init__(self, df_path, features=None, targets=None, filters=None,
+    def __init__(self,
+                 df_path,
+                 features=None,
+                 targets=None,
+                 filters=None,
                  **extra_args):
        """
        Loads the data from a file on disk.
......
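
A hedged usage sketch of this loader (the import path, file name, and column names are invented; the exact filter semantics live in the elided body):

from dbispipeline.dataloaders import DataframeLoader  # import path assumed

loader = DataframeLoader(
    df_path='data/features.pkl',  # invented path to a stored dataframe
    features=['tfidf_vector'],    # invented feature column
    targets=['author_id'],        # invented target column
)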
"""Module containing DB connectors."""
+from sqlalchemy import JSON
from sqlalchemy import Boolean
from sqlalchemy import Column
from sqlalchemy import DateTime
from sqlalchemy import Integer
-from sqlalchemy import JSON
from sqlalchemy import String
from sqlalchemy import create_engine
+from sqlalchemy import func
from sqlalchemy.engine.url import URL
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.declarative import declared_attr
-from sqlalchemy import func
from sqlalchemy.orm import sessionmaker
from .utils import LOGGER
......
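
For context on these imports (a generic SQLAlchemy declarative pattern, not dbispipeline's actual DbModel): declared_attr lets a mixin compute per-class attributes such as the table name, and func.now() supplies a server-side timestamp default.

from sqlalchemy import Column, DateTime, Integer, func
from sqlalchemy.ext.declarative import declarative_base, declared_attr

Base = declarative_base()


class Timestamped:
    """Mixin deriving the table name and adding a creation timestamp."""

    @declared_attr
    def __tablename__(cls):
        return cls.__name__.lower()

    date = Column(DateTime, server_default=func.now())


class Result(Timestamped, Base):
    id = Column(Integer, primary_key=True)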
@@ -7,8 +7,8 @@ import warnings
import numpy as np
import pandas as pd
+from sklearn import metrics
from sklearn.base import clone
-from sklearn import metrics
from sklearn.model_selection import BaseCrossValidator
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import PredefinedSplit
......
@@ -195,10 +195,10 @@ def print_gridsearch_results(results, include_cv_results=False):
        LOGGER.info(f'{score_name}: {best_score}')
        # there could be multiple parameter configurations with the same score
-        best_run_indices = [index
-                            for index, rank
-                            in cv['rank_test_' + score_name].items()
-                            if rank == 1]
+        best_run_indices = [
+            index for index, rank in cv['rank_test_' + score_name].items()
+            if rank == 1
+        ]
        best_params = [cv['params'][index] for index in best_run_indices]
        LOGGER.info('Best parameters:')
        for run_index, param in zip(best_run_indices, best_params):
......
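
For intuition about the comprehension above (a standalone sketch where cv stands in for GridSearchCV's cv_results_): rank_test_<score> assigns rank 1 to the best mean test score, and ties share rank 1, so several indices can be returned.

import pandas as pd

# Stand-in for cv_results_: the last two settings tie for the best rank.
cv = {
    'rank_test_accuracy': pd.Series([2, 1, 1]),
    'params': [{'C': 0.1}, {'C': 1.0}, {'C': 10.0}],
}

best_run_indices = [
    index for index, rank in cv['rank_test_accuracy'].items()
    if rank == 1
]
print(best_run_indices)                             # [1, 2]
print([cv['params'][i] for i in best_run_indices])  # [{'C': 1.0}, {'C': 10.0}]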
@@ -394,7 +394,8 @@ def get_job_id():
    return os.getpid()


-def check_serializability(content, nan_replacement=0.0,
+def check_serializability(content,
+                          nan_replacement=0.0,
                           pos_inf_replacement=0.0,
                           neg_inf_repacement=0.0):
    """
@@ -418,19 +419,22 @@ check_serializability(content, nan_replacement=0.0,
        return content
    except (ValueError, TypeError):
        if type(content) == list or type(content) == tuple:
-            cleaned = [check_serializability(
-                x,
-                nan_replacement=nan_replacement,
-                pos_inf_replacement=pos_inf_replacement,
-                neg_inf_repacement=neg_inf_repacement)
-                for x in content]
+            cleaned = [
+                check_serializability(x,
+                                      nan_replacement=nan_replacement,
+                                      pos_inf_replacement=pos_inf_replacement,
+                                      neg_inf_repacement=neg_inf_repacement)
+                for x in content
+            ]
        elif type(content) == dict:
-            cleaned = {k: check_serializability(
-                v,
-                nan_replacement=nan_replacement,
-                pos_inf_replacement=pos_inf_replacement,
-                neg_inf_repacement=neg_inf_repacement)
-                for k, v in content.items()}
+            cleaned = {
+                k:
+                check_serializability(v,
+                                      nan_replacement=nan_replacement,
+                                      pos_inf_replacement=pos_inf_replacement,
+                                      neg_inf_repacement=neg_inf_repacement)
+                for k, v in content.items()
+            }
        elif callable(content):
            cleaned = inspect.getsource(content).strip()
        elif type(content) == float and np.isnan(content):
......
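
A hedged usage sketch (the import path is assumed, and the NaN/inf branches are elided above, so the replacement values are inferred from the parameter defaults):

import numpy as np

from dbispipeline.utils import check_serializability  # import path assumed


def model(x):
    return x


payload = {
    'score': float('nan'),            # assumed to become nan_replacement (0.0)
    'folds': [1.0, np.inf, -np.inf],  # assumed to use the inf replacements
    'model': model,                   # callables are replaced by their source
}
cleaned = check_serializability(payload)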
"""Tests for the get_results_as_dataframe function."""
import dbispipeline.analytics
from dbispipeline.analytics import SECTION_DATABASE
from dbispipeline.analytics import SECTION_PROJECT
from dbispipeline.analytics import get_results_as_dataframe
from dbispipeline.analytics import load_project_config


class _PandasMock():
    """Stand-in for the pandas module that records read_sql_query calls."""

    def __init__(self, mock_call: dict):
        self.mock_call = mock_call

    def read_sql_query(self, sql, con):
        # Mirror pandas.read_sql_query's argument order and record the call.
        self.mock_call['sql'] = sql
        self.mock_call['con'] = con
        self.mock_call['call_count'] += 1


def test_defaults(monkeypatch):
    """Tests if the defaults are set correctly."""
    mock_call = {
        'call_count': 0,
    }
    test_config = load_project_config()
    test_project = test_config[SECTION_PROJECT]['name']
    test_table = test_config[SECTION_DATABASE]['result_table']
    expected_sql = ' '.join([
        f'SELECT * FROM {test_table}',
        f'WHERE project_name LIKE \'{test_project}\'',
        'AND git_is_dirty = FALSE',
    ])
    mock_pandas = _PandasMock(mock_call)
    with monkeypatch.context() as m:
        m.setattr(dbispipeline.analytics, 'pd', mock_pandas)
        get_results_as_dataframe()
    assert mock_call['call_count'] == 1
    assert mock_call['sql'] == expected_sql


def test_project_name(monkeypatch):
    """Tests if setting the project_name works correctly."""
    mock_call = {
        'call_count': 0,
    }
    test_config = load_project_config()
    test_project = 'some fancy project'
    test_table = test_config[SECTION_DATABASE]['result_table']
    expected_sql = ' '.join([
        f'SELECT * FROM {test_table}',
        f'WHERE project_name LIKE \'{test_project}\'',
        'AND git_is_dirty = FALSE',
    ])
    mock_pandas = _PandasMock(mock_call)
    with monkeypatch.context() as m:
        m.setattr(dbispipeline.analytics, 'pd', mock_pandas)
        get_results_as_dataframe(project_name=test_project)
    assert mock_call['call_count'] == 1
    assert mock_call['sql'] == expected_sql


def test_table_name(monkeypatch):
    """Tests if setting the table_name works correctly."""
    mock_call = {
        'call_count': 0,
    }
    test_config = load_project_config()
    test_project = test_config[SECTION_PROJECT]['name']
    test_table = 'some_test_table_name'
    expected_sql = ' '.join([
        f'SELECT * FROM {test_table}',
        f'WHERE project_name LIKE \'{test_project}\'',
        'AND git_is_dirty = FALSE',
    ])
    mock_pandas = _PandasMock(mock_call)
    with monkeypatch.context() as m:
        m.setattr(dbispipeline.analytics, 'pd', mock_pandas)
        get_results_as_dataframe(table_name=test_table)
    assert mock_call['call_count'] == 1
    assert mock_call['sql'] == expected_sql


def test_filter_git_dirty(monkeypatch):
    """Tests if setting filter_git_dirty works correctly."""
    mock_call = {
        'call_count': 0,
    }
    test_config = load_project_config()
    test_project = test_config[SECTION_PROJECT]['name']
    test_table = test_config[SECTION_DATABASE]['result_table']
    expected_sql = ' '.join([
        f'SELECT * FROM {test_table}',
        f'WHERE project_name LIKE \'{test_project}\'',
    ])
    mock_pandas = _PandasMock(mock_call)
    with monkeypatch.context() as m:
        m.setattr(dbispipeline.analytics, 'pd', mock_pandas)
        get_results_as_dataframe(filter_git_dirty=False)
    assert mock_call['call_count'] == 1
    assert mock_call['sql'] == expected_sql


def test_date_filter(monkeypatch):
    """Tests if the date_filter is set correctly."""
    mock_call = {
        'call_count': 0,
    }
    test_config = load_project_config()
    test_project = test_config[SECTION_PROJECT]['name']
    test_table = test_config[SECTION_DATABASE]['result_table']
    date_filter = '> 2021-01-01'
    expected_sql = ' '.join([
        f'SELECT * FROM {test_table}',
        f'WHERE project_name LIKE \'{test_project}\'',
        'AND git_is_dirty = FALSE',
        f'AND "date" {date_filter}',
    ])
    mock_pandas = _PandasMock(mock_call)
    with monkeypatch.context() as m:
        m.setattr(dbispipeline.analytics, 'pd', mock_pandas)
        get_results_as_dataframe(date_filter=date_filter)
    assert mock_call['call_count'] == 1
    assert mock_call['sql'] == expected_sql


def test_id_filter(monkeypatch):
    """Tests if the id_filter is set correctly."""
    mock_call = {
        'call_count': 0,
    }
    test_config = load_project_config()
    test_project = test_config[SECTION_PROJECT]['name']
    test_table = test_config[SECTION_DATABASE]['result_table']
    id_filter = '32, 42, 52'
    expected_sql = ' '.join([
        f'SELECT * FROM {test_table}',
        f'WHERE project_name LIKE \'{test_project}\'',
        f'AND id IN({id_filter})',
        'AND git_is_dirty = FALSE',
    ])
    mock_pandas = _PandasMock(mock_call)
    with monkeypatch.context() as m:
        m.setattr(dbispipeline.analytics, 'pd', mock_pandas)
        get_results_as_dataframe(id_filter=id_filter)
    assert mock_call['call_count'] == 1
    assert mock_call['sql'] == expected_sql


def test_additional_conditions(monkeypatch):
    """Tests if additional_conditions are set correctly."""
    mock_call = {
        'call_count': 0,
    }
    test_config = load_project_config()
    test_project = test_config[SECTION_PROJECT]['name']
    test_table = test_config[SECTION_DATABASE]['result_table']
    additional_conditions = [
        'test > 4',
        'test 2 <> "Start 1"',
    ]
    expected_sql = [
        f'SELECT * FROM {test_table}',
        f'WHERE project_name LIKE \'{test_project}\'',
        'AND git_is_dirty = FALSE',
    ]
    expected_sql += [f'AND {entry}' for entry in additional_conditions]
    expected_sql = ' '.join(expected_sql)
    mock_pandas = _PandasMock(mock_call)
    with monkeypatch.context() as m:
        m.setattr(dbispipeline.analytics, 'pd', mock_pandas)
        get_results_as_dataframe(additional_conditions=additional_conditions)
    assert mock_call['call_count'] == 1
    assert mock_call['sql'] == expected_sql


def test_columns(monkeypatch):
    """Tests if columns are set correctly."""
    mock_call = {
        'call_count': 0,
    }
    test_config = load_project_config()
    test_project = test_config[SECTION_PROJECT]['name']
    test_table = test_config[SECTION_DATABASE]['result_table']
    columns = [
        'id',
        '"date"',
        'outcome',
    ]
    column_string = ', '.join(columns)
    expected_sql = ' '.join([
        f'SELECT {column_string} FROM {test_table}',
        f'WHERE project_name LIKE \'{test_project}\'',
        'AND git_is_dirty = FALSE',
    ])
    mock_pandas = _PandasMock(mock_call)
    with monkeypatch.context() as m:
        m.setattr(dbispipeline.analytics, 'pd', mock_pandas)
        get_results_as_dataframe(columns=columns)
    assert mock_call['call_count'] == 1
    assert mock_call['sql'] == expected_sql