Commit 42e312e1 authored by Benjamin Murauer's avatar Benjamin Murauer
Browse files

more flexible limiting

parent 84a8e699
......@@ -7,6 +7,7 @@ from typing import List, Optional, Tuple
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold
from utils import LOGGER
from dbispipeline.base import Loader
......@@ -24,14 +25,25 @@ def _attach(df: pd.DataFrame, y: np.array) -> Tuple[pd.DataFrame, str]:
def _limit(
    dataset_part: Tuple[pd.DataFrame, np.array],
    remaining_targets: List[str],
    max_docs_per_target: Optional[int],
) -> Tuple[pd.DataFrame, np.array]:
    """Restrict a dataset part to the selected targets, optionally capping size.

    Args:
        dataset_part: Tuple of (documents DataFrame, label array). The
            DataFrame is expected to have a 'text_raw' column — TODO confirm
            against the concrete loaders.
        remaining_targets: Target labels to keep; rows with other labels
            are dropped.
        max_docs_per_target: If truthy, randomly sample at most this many
            documents per target. Skipped (with a warning) when the smallest
            target group has fewer documents than requested.

    Returns:
        A tuple of (filtered DataFrame without the label column, label array
        aligned with the remaining rows).
    """
    df, key = _attach(dataset_part[0], dataset_part[1])
    sub_df = df[df[key].isin(remaining_targets)]
    # Size of the smallest remaining target group; sampling more than this
    # from any single group would raise, so check before sampling.
    min_population = sub_df.groupby(key).count()['text_raw'].min()
    if max_docs_per_target:
        # '>=' (not '>'): drawing exactly min_population rows without
        # replacement is valid and must not trigger the warning.
        if min_population >= max_docs_per_target:
            sub_df = sub_df.groupby(key).sample(max_docs_per_target)
        else:
            # Logger.warn() is a deprecated alias; use warning() with lazy
            # %-style args. Message punctuation matches the max_authors
            # warning elsewhere in this module.
            LOGGER.warning(
                'Not limiting max_docs_per_target to %d, '
                'population too small (%d)',
                max_docs_per_target,
                min_population,
            )
    return sub_df.drop(columns=key), sub_df[key].values
......@@ -44,8 +56,12 @@ class CrossValidatedSplitLoader(Loader):
parameter.
"""
def __init__(self, n_splits: int = 5, max_targets: int = None,
max_docs_per_target: int = None):
def __init__(
self,
n_splits: int = 5,
max_targets: int = None,
max_docs_per_target: int = None,
):
"""
Initialize the loader.
......@@ -94,10 +110,10 @@ class CrossValidatedSplitLoader(Loader):
else:
splits = []
for train_idx, test_idx in all_splits:
df_train = pd.DataFrame(
dict(idx=train_idx, y=x[key][train_idx]))
df_train = df_train.groupby('y').sample(
self.max_docs_per_target)
df_train = pd.DataFrame(dict(idx=train_idx,
y=x[key][train_idx]))
df_train = df_train.groupby('y')\
.sample(self.max_docs_per_target)
splits.append((df_train.idx.values, test_idx))
return x.drop(columns=[key]), x[key].values, splits
......@@ -110,7 +126,6 @@ class CrossValidatedSplitLoader(Loader):
A tuple of x, y, splits. The splits are something that can be
passed to the GridSearchCV object as the 'cv' parameter.
"""
pass
@property
def configuration(self) -> dict:
......@@ -155,15 +170,23 @@ class TrainTestSplitLoader(Loader):
"""
train, test = self.get_train_data(), self.get_test_data()
all_targets = set(train[1])
if self.max_targets:
if self.max_targets and len(all_targets) > self.max_targets:
selected_targets = random.sample(all_targets, self.max_targets)
elif self.max_targets:
LOGGER.warn(
'Not limiting max_authors to %d, population too small (%d)',
self.max_targets,
len(all_targets),
)
else:
selected_targets = list(all_targets)
train = _limit(train, selected_targets, self.max_docs_per_target)
train = _limit(train, selected_targets,
self.max_docs_per_target)
test = _limit(test, selected_targets, None) # don't limit test data
train_idx = list(range(train[0].shape[0]))
test_idx = list(
range(train[0].shape[0], train[0].shape[0] + test[0].shape[0]))
range(train[0].shape[0],
train[0].shape[0] + test[0].shape[0]))
splits = [(train_idx, test_idx)]
df = pd.concat([train[0], test[0]])
y = np.concatenate([train[1], test[1]])
......@@ -177,7 +200,6 @@ class TrainTestSplitLoader(Loader):
Returns:
A tuple of training data in form of [DataFrame, np.Array]
"""
pass
@abstractmethod
def get_test_data(self) -> Tuple[pd.DataFrame, np.array]:
......@@ -187,7 +209,6 @@ class TrainTestSplitLoader(Loader):
Returns:
A tuple of training data in form of [DataFrame, np.Array]
"""
pass
@property
def configuration(self) -> dict:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment