Commit 441bb480 authored by Benjamin Murauer's avatar Benjamin Murauer
Browse files

Merge branch '54-limiting-loader-should-be-able-to-only-limit-the-training-data' into 'master'

Resolve "limiting loader should be able to only limit the training data"

Closes #54

See merge request dbis/software/dbispipeline!35
parents dcb84b6a 46e75fa7
Pipeline #41305 passed with stage
in 1 minute and 49 seconds
......@@ -146,6 +146,7 @@ def limiting(cls):
*args,
max_targets=None,
max_documents_per_target=None,
restrict_training_only=False,
strategy='random',
random_seed=None,
**kwargs):
......@@ -157,6 +158,10 @@ def limiting(cls):
If set to None, all targets remain in the result.
max_documents_per_target: how many documents should be left for
each target class.
restrict_training_only (bool): if set to true, the limitations on
the amount of documents is only being performed on the training
data, but the loader will always use the full testing (and/or
validation) data.
strategy: how should the items that remain in the result be
selected. Valid options are 'first' or 'random' (default).
random_seed: Seed to provide to random.seed(). Ignored if strategy
......@@ -175,44 +180,55 @@ def limiting(cls):
self.max_targets = max_targets
self.max_documents_per_target = max_documents_per_target
self.restrict_training_only = restrict_training_only
self.strategy = strategy
self.random_seed = random_seed
def new_load(self):
pairs = []
if hasattr(cls, 'load_validate'):
pairs.append(cls.load_validate(self))
if hasattr(cls, 'load_test'):
pairs.append(cls.load_test(self))
pairs.append(cls.load_train(self))
elif hasattr(cls, 'load'):
pairs.append(old_load(self))
dataset = {}
if hasattr(cls, 'load_train'):
dataset['train'] = cls.load_train(self)
if hasattr(cls, 'load_validate'):
dataset['validate'] = cls.load_validate(self)
dataset['test'] = cls.load_test(self)
else:
raise ValueError('No loading methods found for this loader!')
dataset['train'] = old_load(self)
# first, restructure (x, y) pairs to {y: [x0, x1, ...]} dict
# for each of the train, test, validate pairs
dicts = []
for pair in pairs:
dicts = {}
for dataset_part in ['train', 'validate', 'test']:
if dataset_part not in dataset:
continue
pair = dataset[dataset_part]
entry = defaultdict(list)
for data, label in zip(pair[0], pair[1]):
if isinstance(label, np.ndarray) or isinstance(label, list):
label = str(label)
entry[label].append(data)
dicts.append(entry)
dicts[dataset_part] = entry
training_targets = set(dicts[-1].keys())
training_targets = set(dicts['train'].keys())
selected_targets = _sample(values=training_targets,
strategy=self.strategy,
sample_limit=self.max_targets)
result = []
for bunch in dicts:
for dataset_part in ['train', 'validate', 'test']:
if dataset_part not in dicts:
continue
bunch = dicts[dataset_part]
bunch_result = [[], []]
for key in selected_targets:
values = _sample(values=bunch[key],
strategy=self.strategy,
sample_limit=self.max_documents_per_target)
if self.restrict_training_only and dataset_part != 'train':
values = bunch[key]
else:
values = _sample(
values=bunch[key],
strategy=self.strategy,
sample_limit=self.max_documents_per_target)
for value in values:
bunch_result[0].append(value)
bunch_result[1].append(key)
......
......@@ -6,94 +6,135 @@ from dbispipeline import base
class VeryDumbLoader(base.Loader):
    """Produce a very unrandom, simple classification set."""

    def __init__(self, n_samples=80, n_classes=4):
        """Initialize the mocking loader.

        Args:
            n_samples: total number of samples to generate.
            n_classes: number of distinct target classes to cycle through.
        """
        self.n_samples = n_samples
        self.n_classes = n_classes

    @property
    def configuration(self):
        """Return configuration for db."""
        return {
            'n_samples': self.n_samples,
            'n_classes': self.n_classes,
        }

    def load(self):
        """Load data: samples are 0..n_samples-1, labels cycle the classes."""
        samples = list(range(self.n_samples))
        classes = [x % self.n_classes for x in range(self.n_samples)]
        return samples, classes
class MockLoader(base.Loader):
    """Loader using make_classification."""

    def __init__(self, random_state=0):
        """Initialize the mocking loader.

        Args:
            random_state: seed forwarded to make_classification.
        """
        self.random_state = random_state
        # Generate once in the constructor so repeated load() calls return
        # the identical arrays instead of re-sampling each time.
        self.x, self.y = make_classification(random_state=self.random_state)

    def load(self):
        """Load data."""
        return self.x, self.y

    @property
    def configuration(self):
        """Return configuration for db."""
        return {
            'random_state': self.random_state,
        }
class MockTrainTestLoader(base.TrainTestLoader):
    """TrainTestLoader using make_classification."""

    def __init__(self, random_state=0, test_ratio=0.3):
        """Initialize the mocking loader.

        Args:
            random_state: seed for data generation and for the split.
            test_ratio: fraction of samples held out as test data.
        """
        self.random_state = random_state
        self.test_ratio = test_ratio
        self.x, self.y = make_classification(random_state=self.random_state)
        # Split eagerly so load_train/load_test return consistent views of
        # the same generated dataset.
        self.xtrain, self.xtest, self.ytrain, self.ytest = train_test_split(
            self.x,
            self.y,
            test_size=self.test_ratio,
            random_state=self.random_state)

    def load_train(self):
        """Load training data."""
        return self.xtrain, self.ytrain

    def load_test(self):
        """Load testing data."""
        return self.xtest, self.ytest

    @property
    def configuration(self):
        """Return configuration for db."""
        return {
            'random_state': self.random_state,
            'test_ratio': self.test_ratio,
        }
class MockTrainTestValidateLoader(base.TrainValidateTestLoader):
    """TrainValidateTestLoader using make_classification."""

    def __init__(self, random_state=0, test_ratio=0.3):
        """Initialize the mocking loader.

        Args:
            random_state: seed for data generation and for both splits.
            test_ratio: fraction of samples held out; that holdout is then
                split 50/50 into test and validation parts.
        """
        self.random_state = random_state
        self.test_ratio = test_ratio
        self.x, self.y = make_classification(random_state=self.random_state)
        self.xtrain, self.xtest, self.ytrain, self.ytest = train_test_split(
            self.x,
            self.y,
            test_size=self.test_ratio,
            random_state=self.random_state)
        # Split the holdout in half: one part test, the other validation.
        self.xtest, self.xval, self.ytest, self.yval = train_test_split(
            self.xtest,
            self.ytest,
            test_size=0.5,
            random_state=self.random_state)

    def load_train(self):
        """Load training data."""
        return self.xtrain, self.ytrain

    def load_test(self):
        """Load testing data."""
        return self.xtest, self.ytest

    def load_validate(self):
        """Load validation data."""
        return self.xval, self.yval

    @property
    def configuration(self):
        """Return configuration for db."""
        # NOTE(review): the fixed 0.5 test/validation sub-split is not
        # recorded in the configuration — confirm that is intentional.
        return {
            'random_state': self.random_state,
            'test_ratio': self.test_ratio,
        }
class MockMultiLoader(MockLoader):
""" Mocking loader which loads a list of valid runs"""
"""Mocking loader which loads a list of valid runs."""
def __init__(self, run_count, random_state=0):
"""Initializes Mocking loader."""
super().__init__(random_state)
self._run_count = run_count
@property
def run_count(self):
"""Returns the number of runs this loader creates."""
return self._run_count
@property
def configuration(self):
"""Return configuration for db."""
configurations = []
for i in range(self.run_count):
conf = super().configuration
......@@ -102,5 +143,6 @@ class MockMultiLoader(MockLoader):
return configurations
def load(self):
for i in range(self.run_count):
yield (super().load())
"""Load data."""
for _ in range(self.run_count):
yield super().load()
......@@ -5,6 +5,7 @@ import warnings
from dbispipeline.dataloaders.testing_utils import DataloaderUnitTest
from dbispipeline.dataloaders.wrappers import limiting
from .dummy_loader import MockTrainTestValidateLoader
from .dummy_loader import VeryDumbLoader
from .dummy_loader_multi import DummyTrainTestLoader
from .dummy_loader_multi import DummyTrainValidateTestLoader
......@@ -131,3 +132,18 @@ class TestLimitingLoader(DataloaderUnitTest):
'n_samples': 10,
}
self.assertRaises(ValueError, LimitedDumbLoader, **loader_arguments)
def test_limit_training_data_only(self):
"""Only limit the training part of the dataset, not the rest."""
loader = limiting(MockTrainTestValidateLoader)(
test_ratio=0.66,
max_documents_per_target=2,
restrict_training_only=True,
)
# loader = MockTrainTestValidateLoader()
train, test, validate = loader.load()
self.assertLessEqual(len(train[0]), 2 * 2)
self.assertGreaterEqual(len(test[0]), 2 * 2)
self.assertGreaterEqual(len(validate[0]), 2 * 2)
self._loader_sanity_check(*train, *test, *validate)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment