Skip to content

Commit

Permalink
✨Introduce SourceDataProviderFilter (#26)
Browse files Browse the repository at this point in the history
  • Loading branch information
hf-kklein authored Feb 1, 2023
1 parent 0754db1 commit 5759cd1
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 3 deletions.
70 changes: 70 additions & 0 deletions src/bomf/filter/sourcedataproviderfilter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""
Source Data Provider Filters combine the features of a filter with the features of a Source Data Provider.
"""

from typing import Callable, Generic, List, Literal, Optional, overload

from bomf import KeyTyp, SourceDataProvider
from bomf.filter import Candidate, Filter
from bomf.provider import JsonFileSourceDataProvider, ListBasedSourceDataProvider


# pylint:disable=too-few-public-methods
class SourceDataProviderFilter(Generic[Candidate, KeyTyp]):
"""
a filter that works on and returns a CandidateSourceDataProvider
"""

def __init__(self, candidate_filter: Filter[Candidate]):
"""
instantiate by providing a filter which can be applied on the data providers source data models
"""
self._filter = candidate_filter

@overload
async def apply(self, source_data_provider: JsonFileSourceDataProvider) -> SourceDataProvider[Candidate, KeyTyp]:
...

@overload
async def apply(self, source_data_provider: ListBasedSourceDataProvider) -> SourceDataProvider[Candidate, KeyTyp]:
...

@overload
async def apply(
self, source_data_provider: JsonFileSourceDataProvider, key_selector: Literal[None]
) -> SourceDataProvider[Candidate, KeyTyp]:
...

@overload
async def apply(
self, source_data_provider: ListBasedSourceDataProvider, key_selector: Literal[None]
) -> SourceDataProvider[Candidate, KeyTyp]:
...

async def apply(
self,
source_data_provider: SourceDataProvider[Candidate, KeyTyp],
key_selector: Optional[Callable[[Candidate], KeyTyp]] = None,
) -> SourceDataProvider[Candidate, KeyTyp]:
"""
Reads all the data from the given source_data_provider, applies the filtering, then returns a new source
data provider that only contains those entries that passed the filter (its predicate).
If the provided source_data_provider is a JsonFileSourceDataProvider, then you don't have to provide a
key_selector (let it default to None).
However, in general, you have to specify how the data can be indexed using a key_selector which is not None.
If you provide both a JsonFileSourceDataProvider AND a key_selector, the explicit key_selector will be used.
"""
survivors: List[Candidate] = await self._filter.apply(source_data_provider.get_data())
key_selector_to_be_used: Callable[[Candidate], KeyTyp]
if key_selector is not None:
key_selector_to_be_used = key_selector
else:
key_selector_to_be_used = source_data_provider.key_selector # type:ignore[attr-defined]
# if this raises an attribute error you have to
# * either provide a source_data_provider which has a key_selector attribute
# * or explicitly provide a key_selector as (non-None) argument
filtered_data_provider_class = ListBasedSourceDataProvider(
source_data_models=survivors, key_selector=key_selector_to_be_used
)
return filtered_data_provider_class
23 changes: 22 additions & 1 deletion src/bomf/provider/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import json
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Callable, Generic, List, Mapping, Optional, TypeVar, Union
from typing import Callable, Generic, List, Mapping, Optional, Protocol, TypeVar, Union

SourceDataModel = TypeVar("SourceDataModel")
"""
Expand Down Expand Up @@ -39,6 +39,26 @@ def get_entry(self, key: KeyTyp) -> SourceDataModel:
"""


class ListBasedSourceDataProvider(SourceDataProvider[SourceDataModel, KeyTyp]):
"""
A source data provider that is instantiated with a list of source data models
"""

def __init__(self, source_data_models: List[SourceDataModel], key_selector: Callable[[SourceDataModel], KeyTyp]):
"""
instantiate it by providing a list of source data models
"""
self._models: List[SourceDataModel] = source_data_models
self._models_dict: Mapping[KeyTyp, SourceDataModel] = {key_selector(m): m for m in source_data_models}
self.key_selector = key_selector

def get_entry(self, key: KeyTyp) -> SourceDataModel:
return self._models_dict[key]

def get_data(self) -> List[SourceDataModel]:
return self._models


class JsonFileSourceDataProvider(SourceDataProvider[SourceDataModel, KeyTyp], Generic[SourceDataModel, KeyTyp]):
"""
a source data model provider that is based on a JSON file
Expand All @@ -61,6 +81,7 @@ def __init__(
self._key_to_data_model_mapping: Mapping[KeyTyp, SourceDataModel] = {
key_selector(sdm): sdm for sdm in self._source_data_models
}
self.key_selector = key_selector

def get_data(self) -> List[SourceDataModel]:
return self._source_data_models
Expand Down
45 changes: 45 additions & 0 deletions unittests/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import pytest # type:ignore[import]

from bomf.filter import AggregateFilter, Filter
from bomf.filter.sourcedataproviderfilter import SourceDataProviderFilter
from bomf.provider import ListBasedSourceDataProvider, SourceDataProvider


class _FooFilter(Filter):
Expand Down Expand Up @@ -98,3 +100,46 @@ async def test_aggregate_filter(
assert actual == survivors
assert "There are 4 candidates and 4 aggregates" in caplog.messages
assert "There are 2 filtered aggregates left" in caplog.messages


class TestSourceDataProviderFilter:
@pytest.mark.parametrize(
"candidate_filter,candidates,survivors",
[
pytest.param(
_BarFilter(),
[
_MyCandidate(number=1, string="foo"),
_MyCandidate(number=19, string="bar"),
_MyCandidate(number=2, string="foo"),
_MyCandidate(number=17, string="bar"),
],
[_MyCandidate(number=19, string="bar"), _MyCandidate(number=2, string="foo")],
),
],
)
async def test_source_data_provider_filter(
self,
candidate_filter: Filter[_MyCandidate],
candidates: List[_MyCandidate],
survivors: List[_MyCandidate],
caplog,
):
my_provider: ListBasedSourceDataProvider[_MyCandidate, int] = ListBasedSourceDataProvider(
candidates, key_selector=lambda mc: mc.number
)
sdp_filter: SourceDataProviderFilter[_MyCandidate, int] = SourceDataProviderFilter(candidate_filter)
caplog.set_level(logging.DEBUG, logger=self.__module__)
filtered_provider = await sdp_filter.apply(my_provider)
assert isinstance(filtered_provider, SourceDataProvider)
actual = filtered_provider.get_data()
assert actual == survivors
assert "There are 4 candidates and 4 aggregates" in caplog.messages
assert "There are 2 filtered aggregates left" in caplog.messages

async def test_source_data_provider_filter_error(self):
my_provider = ListBasedSourceDataProvider([{"foo": "bar"}, {"foo": "notbar"}], key_selector=lambda d: d["foo"])
del my_provider.key_selector
sdp_filter: SourceDataProviderFilter[_MyCandidate, int] = SourceDataProviderFilter(_FooFilter())
with pytest.raises(AttributeError):
await sdp_filter.apply(my_provider)
11 changes: 9 additions & 2 deletions unittests/test_source_data_provider.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from pathlib import Path
from typing import List, Optional
from typing import List

import pytest # type:ignore[import]

from bomf.provider import JsonFileSourceDataProvider, KeyTyp, SourceDataProvider
from bomf.provider import JsonFileSourceDataProvider, KeyTyp, ListBasedSourceDataProvider, SourceDataProvider


class LegacyDataSystemDataProvider(SourceDataProvider):
Expand Down Expand Up @@ -39,3 +39,10 @@ def test_json_file_provider(self, datafiles):
assert example_json_data_provider.get_entry("world") == {"myKey": "world", "qwe": "rtz"}
with pytest.raises(KeyError):
_ = example_json_data_provider.get_entry("something unknown")


class TestListBasedSourceDataProvider:
def test_list_based_provider(self):
my_provider = ListBasedSourceDataProvider(["foo", "bar", "baz"], key_selector=lambda x: x)
assert len(my_provider.get_data()) == 3
assert my_provider.get_entry("bar") == "bar"

0 comments on commit 5759cd1

Please sign in to comment.