Commit dcb84b6a authored by Benjamin Murauer
Browse files

Merge branch '56-make-multiloaders-nestable' into 'master'

Resolve "Make Multiloaders nestable"

Closes #56

See merge request dbis/software/dbispipeline!37
parents 6fca1257 e0cfc01a
Pipeline #41304 passed with stage
in 2 minutes and 4 seconds
This diff is collapsed.
......@@ -20,12 +20,12 @@ class Loader(ABC):
@property
def run_count(self):
"""Returns how many configurations this dataloader will produce."""
return 0
return 1
@property
def is_multiloader(self):
"""Returns if this dataloader will produce multiple configurations."""
return self.run_count > 0
return self.run_count > 1
class TrainTestLoader(Loader):
......
......@@ -11,44 +11,6 @@ from dbispipeline.base import Loader
from dbispipeline.base import TrainTestLoader
class RepeatingLoader(Loader):
    """Repeats a specific loader a certain number of times.

    A single instance of the wrapped loader is created up front; its
    ``load()`` result is yielded once per repetition.
    """

    def __init__(self, loader_class, loader_kwargs, repetitions):
        """
        Initializes the loader.

        Args:
            loader_class: Loader
                the class of the loader to be initialized
            loader_kwargs:
                the keyword arguments used to initialize the loader
            repetitions:
                how many runs should be produced
        """
        self.loader = loader_class(**loader_kwargs)
        self.repetitions = repetitions

    def load(self):
        """Yields the wrapped loader's data once per repetition."""
        for _ in range(self.repetitions):
            yield self.loader.load()

    @property
    def run_count(self):
        """Returns how many runs this loader will produce."""
        return self.repetitions

    @property
    def configuration(self):
        """DB-Representation of this loader.

        Returns:
            a new dict holding the wrapped loader's configuration plus a
            ``repetitions`` entry.
        """
        # Copy first: the wrapped loader may hand out a dict it keeps a
        # reference to, and that dict must not gain a 'repetitions' key.
        config = dict(self.loader.configuration)
        config['repetitions'] = self.repetitions
        return config
class MultiLoaderWrapper(Loader):
"""Simple wrapper sequentially yielding the result of a list of loaders."""
......@@ -64,25 +26,62 @@ class MultiLoaderWrapper(Loader):
def load(self):
"""Loads the data."""
for loader in self.loaders:
yield loader.load()
if loader.is_multiloader:
for run in loader.load():
yield run
else:
yield loader.load()
@property
def run_count(self):
"""Returns how many runs this loader will produce."""
return len(self.loaders)
return sum([loader.run_count for loader in self.loaders])
@property
def configuration(self):
"""DB-Representation of this loader."""
for i, loader in enumerate(self.loaders):
yield dict(
**loader.configuration,
run_number=i,
loader_class=loader.__class__.__name__,
)
i = 0
def prepare_config(config):
# note that this might overwrite a nested multi loader
config['run_number'] = i
if 'loader_class' in config:
config[f'loader_class_{i}'] = config['loader_class']
config['loader_class'] = loader.__class__.__name__
return config
for loader in self.loaders:
if loader.is_multiloader:
for sub_config in loader.configuration:
yield prepare_config(sub_config)
i += 1
else:
yield prepare_config(loader.configuration)
i += 1
class RepeatingLoader(MultiLoaderWrapper):
    """Wraps several identically configured instances of one loader class."""

    def __init__(self, loader_class, loader_kwargs, repetitions):
        """
        Creates the repeated loaders and passes them to the parent class.

        Args:
            loader_class: Loader
                the loader class that is instantiated for every repetition
            loader_kwargs:
                the keyword arguments passed to each instantiation
            repetitions:
                how many loader instances are created
        """
        # a separate instance is built for every repetition
        copies = [loader_class(**loader_kwargs) for _ in range(repetitions)]
        super().__init__(copies)
class MultiLoaderGenerator(Loader):
class MultiLoaderGenerator(MultiLoaderWrapper):
"""Produces a MultiLoader by specifying a range of possible parameters."""
def __init__(self, loader_class, parameters):
......@@ -99,36 +98,18 @@ class MultiLoaderGenerator(Loader):
If passed a dict, a grid of all combinations is generated and
passed to the loader.
"""
self.loaders = []
loaders = []
if isinstance(parameters, dict):
for sample in ParameterGrid(parameters):
# this produces only dicts
self.loaders.append(loader_class(**sample))
loaders.append(loader_class(**sample))
else:
for sample in parameters:
if isinstance(sample, dict):
self.loaders.append(loader_class(**sample))
loaders.append(loader_class(**sample))
else:
self.loaders.append(loader_class(*sample))
def load(self):
"""Loads the data."""
for loader in self.loaders:
yield loader.load()
@property
def configuration(self):
"""DB-Representation of this loader."""
for i, loader in enumerate(self.loaders):
config = loader.configuration
config['run_number'] = i
config['class'] = loader.__class__.__name__
yield config
@property
def run_count(self):
"""Returns how many runs this loader will produce."""
return len(self.loaders)
loaders.append(loader_class(*sample))
super().__init__(loaders)
def label_encode(cls):
......
"""Tests wrappers that produce multi loaders."""
import unittest
from dbispipeline.base import Loader
......@@ -5,25 +7,31 @@ from dbispipeline.dataloaders.wrappers import MultiLoaderGenerator
class TinyTestLoader(Loader):
    """Minimal loader for tests, configured by two arbitrary values."""

    def __init__(self, parameter1, parameter2):
        """Store both parameters on the instance."""
        self.parameter1, self.parameter2 = parameter1, parameter2

    @property
    def configuration(self):
        """Return both parameters as a dict keyed by parameter name."""
        return dict(parameter1=self.parameter1, parameter2=self.parameter2)

    def load(self):
        """Return the two parameters as a tuple in place of real data."""
        first, second = self.parameter1, self.parameter2
        return first, second
class TestMultiLoaderGenerator(unittest.TestCase):
"""Test all variants of MultiLoaderGenerator."""
def test_explicit_tuples(self):
"""Test passing an explicit list of parameters as tuples."""
parameters = [
(1, 'a'),
(1, 'b'),
......@@ -34,48 +42,46 @@ class TestMultiLoaderGenerator(unittest.TestCase):
self.assertEqual(dataloader.run_count, 4)
configs = list(dataloader.configuration)
self.assertEqual(len(configs), 4)
self.assertTrue({
config = {
'parameter1': 2,
'parameter2': 'a',
'run_number': 2,
'class': 'TinyTestLoader'
} in configs)
'loader_class': 'TinyTestLoader',
}
self.assertTrue(config in configs)
data = list(dataloader.load())
self.assertEqual(len(data), 4),
self.assertEqual(len(data), 4)
self.assertTrue((1, 'b') in data)
self.assertTrue((2, 'a') in data)
self.assertTrue((3, 'b') in data)
self.assertFalse((5, 'b') in data)
def test_explicit_dicts(self):
"""
tests the generation of a multiloader by using explicit parameters
"""
"""Test passing an explicit list of parameters as kwargs-dicts."""
parameters = [
{
'parameter1': 1,
'parameter2': 'a'
'parameter2': 'a',
},
{
'parameter1': 1,
'parameter2': 'b'
'parameter2': 'b',
},
{
'parameter1': 2,
'parameter2': 'a'
'parameter2': 'a',
},
{
'parameter1': 2,
'parameter2': 'b'
'parameter2': 'b',
},
{
'parameter1': 3,
'parameter2': 'a'
'parameter2': 'a',
},
{
'parameter1': 3,
'parameter2': 'b'
'parameter2': 'b',
},
]
......@@ -84,24 +90,22 @@ class TestMultiLoaderGenerator(unittest.TestCase):
self.assertEqual(dataloader.run_count, 6)
configs = list(dataloader.configuration)
self.assertEqual(len(configs), 6)
self.assertTrue({
config = {
'parameter1': 2,
'parameter2': 'b',
'class': 'TinyTestLoader',
'run_number': 3
} in configs)
'loader_class': 'TinyTestLoader',
'run_number': 3,
}
self.assertTrue(config in configs)
data = list(dataloader.load())
self.assertEqual(len(data), 6),
self.assertEqual(len(data), 6)
self.assertTrue((1, 'b') in data)
self.assertTrue((2, 'b') in data)
self.assertTrue((3, 'b') in data)
self.assertFalse((5, 'b') in data)
def test_generated_dictionary(self):
"""
tests the generation of a multiloader by using a parameter dict
"""
"""Test passing a dict of possible parameter values."""
parameters = {
'parameter1': [1, 2, 3],
'parameter2': ['a', 'b'],
......@@ -118,8 +122,49 @@ class TestMultiLoaderGenerator(unittest.TestCase):
configs))
self.assertEqual(len(found), 1)
data = list(dataloader.load())
self.assertEqual(len(data), 6),
self.assertEqual(len(data), 6)
self.assertTrue((1, 'b') in data)
self.assertTrue((2, 'b') in data)
self.assertTrue((3, 'b') in data)
self.assertFalse((5, 'b') in data)
def test_nested_multiloaders(self):
    """Test the run counts when nesting a MultiLoaderGenerator."""
    # the same four parameter combinations, in each supported input form
    tuple_params = [('a', 1), ('a', 2), ('b', 1), ('b', 2)]
    dict_params = [
        {'parameter1': 'a', 'parameter2': 1},
        {'parameter1': 'a', 'parameter2': 2},
        {'parameter1': 'b', 'parameter2': 1},
        {'parameter1': 'b', 'parameter2': 2},
    ]
    grid_params = {
        'parameter1': ['a', 'b'],
        'parameter2': [1, 2],
    }

    nested = MultiLoaderGenerator(
        loader_class=MultiLoaderGenerator,
        parameters={
            'loader_class': [TinyTestLoader],
            'parameters': [tuple_params, dict_params, grid_params],
        },
    )

    # three inner parameter sets with four combinations each
    total = 12
    self.assertEqual(nested.run_count, total)
    self.assertEqual(len(list(nested.load())), total)
    self.assertEqual(len(list(nested.configuration)), total)
......@@ -35,6 +35,33 @@ class TestRepeatingLoader(unittest.TestCase):
params = {'parameter1': 1, 'parameter2': 2}
loader = RepeatingLoader(TinyTestLoader, params, n)
self.assertEqual(loader.run_count, n)
self.assertEqual(loader.configuration['repetitions'], n)
for data in loader.load():
self.assertEqual(data, tuple(params.values()))
def test_nested_multiloaders_with_conflicting_names(self):
    """Test correct configuration of nested loaders.

    A RepeatingLoader is nested inside another RepeatingLoader, so both
    levels write a ``loader_class`` entry. Every flattened run must keep
    the inner class name under ``loader_class_<run_number>`` and carry
    the outer class name under ``loader_class``.
    """
    inner_repetitions = 2
    outer_repetitions = 5
    expected_runs = inner_repetitions * outer_repetitions

    loader = RepeatingLoader(
        loader_class=RepeatingLoader,
        loader_kwargs=dict(
            loader_class=TinyTestLoader,
            loader_kwargs=dict(
                parameter1='a',
                parameter2='b',
            ),
            repetitions=inner_repetitions,
        ),
        repetitions=outer_repetitions,
    )
    self.assertEqual(loader.run_count, expected_runs)

    configs = list(loader.configuration)
    # same length check as the other multiloader tests: no extra configs
    self.assertEqual(len(configs), expected_runs)
    for i in range(expected_runs):
        config = {
            f'loader_class_{i}': 'TinyTestLoader',
            'loader_class': 'RepeatingLoader',
            'run_number': i,
            'parameter1': 'a',
            'parameter2': 'b',
        }
        # assertIn reports the missing config on failure
        self.assertIn(config, configs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment