From 11833b6fe149bf241a9e693df878eda52b1ad2bf Mon Sep 17 00:00:00 2001
From: konstantin
Date: Thu, 2 Feb 2023 10:37:55 +0100
Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Introduce=20`BlocklistFilter`=20and?=
 =?UTF-8?q?=20`AllowlistFilter`=20(#29)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/bomf/filter/__init__.py | 46 ++++++++++++++++++++++++++++++++++++-
 unittests/test_filter.py    | 18 ++++++++++++++++-
 2 files changed, 62 insertions(+), 2 deletions(-)

diff --git a/src/bomf/filter/__init__.py b/src/bomf/filter/__init__.py
index 0228351..2b2f623 100644
--- a/src/bomf/filter/__init__.py
+++ b/src/bomf/filter/__init__.py
@@ -7,7 +7,7 @@
 import asyncio
 import logging
 from abc import ABC, abstractmethod
-from typing import Awaitable, Generic, List, TypeVar
+from typing import Awaitable, Callable, Generic, List, Set, TypeVar
 
 Candidate = TypeVar("Candidate")  #: an arbitrary but fixed type on which the filter operates
 
@@ -99,3 +99,47 @@ async def apply(self, candidates: List[Candidate]) -> List[Candidate]:
         filtered_aggregates = await self._base_filter.apply(aggregates)
         self._logger.info("There are %i filtered aggregates left", len(filtered_aggregates))
         return [self.disaggregate(fa) for fa in filtered_aggregates]
+
+
+CandidateProperty = TypeVar("CandidateProperty")
+
+
+class HardcodedFilter(Filter[Candidate], ABC, Generic[Candidate, CandidateProperty]):
+    """
+    a hardcoded filter filters on a hardcoded set of allowed/blocked values (formerly known as white- and blacklist)
+    """
+
+    def __init__(self, criteria_selector: Callable[[Candidate], CandidateProperty], values: Set[CandidateProperty]):
+        """
+        instantiate by providing a criteria selector that returns a property on which we can filter and a set of values.
+        Whether the values are used as allowed or not allowed (block) depends on the inheriting class
+        """
+        super().__init__()
+        self._criteria_selector = criteria_selector
+        self._values = values
+
+
+class BlocklistFilter(HardcodedFilter[Candidate, CandidateProperty]):
+    """
+    remove those candidates whose property is in the provided blocklist
+    """
+
+    async def predicate(self, candidate: Candidate) -> bool:
+        candidate_property: CandidateProperty = self._criteria_selector(candidate)
+        result = candidate_property not in self._values
+        if result is False:
+            self._logger.debug("'%s' is in the blocklist", candidate_property)
+        return result
+
+
+class AllowlistFilter(HardcodedFilter[Candidate, CandidateProperty]):
+    """
+    let those candidates pass, whose property is in the provided allowlist
+    """
+
+    async def predicate(self, candidate: Candidate) -> bool:
+        candidate_property: CandidateProperty = self._criteria_selector(candidate)
+        result = candidate_property in self._values
+        if result is False:
+            self._logger.debug("'%s' is not in the allowlist", candidate_property)
+        return result
diff --git a/unittests/test_filter.py b/unittests/test_filter.py
index 4aeb13c..34480c5 100644
--- a/unittests/test_filter.py
+++ b/unittests/test_filter.py
@@ -5,7 +5,7 @@
 
 import pytest  # type:ignore[import]
 
-from bomf.filter import AggregateFilter, Filter
+from bomf.filter import AggregateFilter, AllowlistFilter, BlocklistFilter, Filter
 from bomf.filter.sourcedataproviderfilter import SourceDataProviderFilter
 from bomf.provider import ListBasedSourceDataProvider, SourceDataProvider
 
@@ -102,6 +102,22 @@ async def test_aggregate_filter(
         assert "There are 2 filtered aggregates left" in caplog.messages
 
 
+class TestBlockAndAllowlistFilter:
+    async def test_allowlist_filter(self):
+        allowlist = {"A", "B", "C"}
+        candidates: List[dict[str, str]] = [{"foo": "A"}, {"foo": "B"}, {"foo": "Z"}]
+        allowlist_filter: AllowlistFilter[dict[str, str], str] = AllowlistFilter(lambda c: c["foo"], allowlist)
+        actual = await allowlist_filter.apply(candidates)
+        assert actual == [{"foo": "A"}, {"foo": "B"}]
+
+    async def test_blocklist_filter(self):
+        blocklist = {"A", "B", "C"}
+        candidates: List[dict[str, str]] = [{"foo": "A"}, {"foo": "B"}, {"foo": "Z"}]
+        blocklist_filter: BlocklistFilter[dict[str, str], str] = BlocklistFilter(lambda c: c["foo"], blocklist)
+        actual = await blocklist_filter.apply(candidates)
+        assert actual == [{"foo": "Z"}]
+
+
 class TestSourceDataProviderFilter:
     @pytest.mark.parametrize(
         "candidate_filter,candidates,survivors",