diff --git a/src/bomf/filter/sourcedataproviderfilter.py b/src/bomf/filter/sourcedataproviderfilter.py new file mode 100644 index 0000000..a96c0f3 --- /dev/null +++ b/src/bomf/filter/sourcedataproviderfilter.py @@ -0,0 +1,70 @@ +""" +Source Data Provider Filters combine the features of a filter with the features of a Source Data Provider. +""" + +from typing import Callable, Generic, List, Literal, Optional, overload + +from bomf import KeyTyp, SourceDataProvider +from bomf.filter import Candidate, Filter +from bomf.provider import JsonFileSourceDataProvider, ListBasedSourceDataProvider + + +# pylint:disable=too-few-public-methods +class SourceDataProviderFilter(Generic[Candidate, KeyTyp]): + """ + a filter that works on and returns a CandidateSourceDataProvider + """ + + def __init__(self, candidate_filter: Filter[Candidate]): + """ + instantiate by providing a filter which can be applied on the data providers source data models + """ + self._filter = candidate_filter + + @overload + async def apply(self, source_data_provider: JsonFileSourceDataProvider) -> SourceDataProvider[Candidate, KeyTyp]: + ... + + @overload + async def apply(self, source_data_provider: ListBasedSourceDataProvider) -> SourceDataProvider[Candidate, KeyTyp]: + ... + + @overload + async def apply( + self, source_data_provider: JsonFileSourceDataProvider, key_selector: Literal[None] + ) -> SourceDataProvider[Candidate, KeyTyp]: + ... + + @overload + async def apply( + self, source_data_provider: ListBasedSourceDataProvider, key_selector: Literal[None] + ) -> SourceDataProvider[Candidate, KeyTyp]: + ... + + async def apply( + self, + source_data_provider: SourceDataProvider[Candidate, KeyTyp], + key_selector: Optional[Callable[[Candidate], KeyTyp]] = None, + ) -> SourceDataProvider[Candidate, KeyTyp]: + """ + Reads all the data from the given source_data_provider, applies the filtering, then returns a new source + data provider that only contains those entries that passed the filter (its predicate). + + If the provided source_data_provider is a JsonFileSourceDataProvider, then you don't have to provide a + key_selector (let it default to None). + However, in general, you have to specify how the data can be indexed using a key_selector which is not None. + If you provide both a JsonFileSourceDataProvider AND a key_selector, the explicit key_selector will be used. + """ + survivors: List[Candidate] = await self._filter.apply(source_data_provider.get_data()) + key_selector_to_be_used: Callable[[Candidate], KeyTyp] + if key_selector is not None: + key_selector_to_be_used = key_selector + else: + key_selector_to_be_used = source_data_provider.key_selector # type:ignore[attr-defined] + # if this raises an attribute error you have to + # * either provide a source_data_provider which has a key_selector attribute + # * or explicitly provide a key_selector as (non-None) argument + filtered_data_provider_class = ListBasedSourceDataProvider( + source_data_models=survivors, key_selector=key_selector_to_be_used + ) + return filtered_data_provider_class diff --git a/src/bomf/provider/__init__.py b/src/bomf/provider/__init__.py index 7a6e9ab..54a0f4f 100644 --- a/src/bomf/provider/__init__.py +++ b/src/bomf/provider/__init__.py @@ -4,7 +4,7 @@ import json from abc import ABC, abstractmethod from pathlib import Path -from typing import Callable, Generic, List, Mapping, Optional, TypeVar, Union +from typing import Callable, Generic, List, Mapping, Optional, Protocol, TypeVar, Union SourceDataModel = TypeVar("SourceDataModel") """ @@ -39,6 +39,26 @@ def get_entry(self, key: KeyTyp) -> SourceDataModel: """ +class ListBasedSourceDataProvider(SourceDataProvider[SourceDataModel, KeyTyp]): + """ + A source data provider that is instantiated with a list of source data models + """ + + def __init__(self, source_data_models: List[SourceDataModel], key_selector: Callable[[SourceDataModel], KeyTyp]): + """ + instantiate it by providing a list of source data models + """ + self._models: List[SourceDataModel] = source_data_models + self._models_dict: Mapping[KeyTyp, SourceDataModel] = {key_selector(m): m for m in source_data_models} + self.key_selector = key_selector + + def get_entry(self, key: KeyTyp) -> SourceDataModel: + return self._models_dict[key] + + def get_data(self) -> List[SourceDataModel]: + return self._models + + class JsonFileSourceDataProvider(SourceDataProvider[SourceDataModel, KeyTyp], Generic[SourceDataModel, KeyTyp]): """ a source data model provider that is based on a JSON file @@ -61,6 +81,7 @@ def __init__( self._key_to_data_model_mapping: Mapping[KeyTyp, SourceDataModel] = { key_selector(sdm): sdm for sdm in self._source_data_models } + self.key_selector = key_selector def get_data(self) -> List[SourceDataModel]: return self._source_data_models diff --git a/unittests/test_filter.py b/unittests/test_filter.py index 639bb70..4aeb13c 100644 --- a/unittests/test_filter.py +++ b/unittests/test_filter.py @@ -6,6 +6,8 @@ import pytest # type:ignore[import] from bomf.filter import AggregateFilter, Filter +from bomf.filter.sourcedataproviderfilter import SourceDataProviderFilter +from bomf.provider import ListBasedSourceDataProvider, SourceDataProvider class _FooFilter(Filter): @@ -98,3 +100,46 @@ async def test_aggregate_filter( assert actual == survivors assert "There are 4 candidates and 4 aggregates" in caplog.messages assert "There are 2 filtered aggregates left" in caplog.messages + + +class TestSourceDataProviderFilter: + @pytest.mark.parametrize( + "candidate_filter,candidates,survivors", + [ + pytest.param( + _BarFilter(), + [ + _MyCandidate(number=1, string="foo"), + _MyCandidate(number=19, string="bar"), + _MyCandidate(number=2, string="foo"), + _MyCandidate(number=17, string="bar"), + ], + [_MyCandidate(number=19, string="bar"), _MyCandidate(number=2, string="foo")], + ), + ], + ) + async def test_source_data_provider_filter( + self, + candidate_filter: Filter[_MyCandidate], + candidates: List[_MyCandidate], + survivors: List[_MyCandidate], + caplog, + ): + my_provider: ListBasedSourceDataProvider[_MyCandidate, int] = ListBasedSourceDataProvider( + candidates, key_selector=lambda mc: mc.number + ) + sdp_filter: SourceDataProviderFilter[_MyCandidate, int] = SourceDataProviderFilter(candidate_filter) + caplog.set_level(logging.DEBUG, logger=self.__module__) + filtered_provider = await sdp_filter.apply(my_provider) + assert isinstance(filtered_provider, SourceDataProvider) + actual = filtered_provider.get_data() + assert actual == survivors + assert "There are 4 candidates and 4 aggregates" in caplog.messages + assert "There are 2 filtered aggregates left" in caplog.messages + + async def test_source_data_provider_filter_error(self): + my_provider = ListBasedSourceDataProvider([{"foo": "bar"}, {"foo": "notbar"}], key_selector=lambda d: d["foo"]) + del my_provider.key_selector + sdp_filter: SourceDataProviderFilter[_MyCandidate, int] = SourceDataProviderFilter(_FooFilter()) + with pytest.raises(AttributeError): + await sdp_filter.apply(my_provider) diff --git a/unittests/test_source_data_provider.py b/unittests/test_source_data_provider.py index 904c26e..0cfbd8c 100644 --- a/unittests/test_source_data_provider.py +++ b/unittests/test_source_data_provider.py @@ -1,9 +1,9 @@ from pathlib import Path -from typing import List, Optional +from typing import List import pytest # type:ignore[import] -from bomf.provider import JsonFileSourceDataProvider, KeyTyp, SourceDataProvider +from bomf.provider import JsonFileSourceDataProvider, KeyTyp, ListBasedSourceDataProvider, SourceDataProvider class LegacyDataSystemDataProvider(SourceDataProvider): @@ -39,3 +39,10 @@ def test_json_file_provider(self, datafiles): assert example_json_data_provider.get_entry("world") == {"myKey": "world", "qwe": "rtz"} with pytest.raises(KeyError): _ = example_json_data_provider.get_entry("something unknown") + + +class TestListBasedSourceDataProvider: + def test_list_based_provider(self): + my_provider = ListBasedSourceDataProvider(["foo", "bar", "baz"], key_selector=lambda x: x) + assert len(my_provider.get_data()) == 3 + assert my_provider.get_entry("bar") == "bar"