diff --git a/pyproject.toml b/pyproject.toml index 1c10abe..1c9113f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,9 @@ dependencies = [ "pydantic>=2.0.0", "aiohttp[speedups]>=3.9.3", "more-itertools", - "pytz" + "pytz", + "pyjwt", + "aioauth_client", ] # add all the dependencies here dynamic = ["readme", "version"] diff --git a/requirements.txt b/requirements.txt index 5d408fc..6d295b1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,22 +4,39 @@ # # pip-compile pyproject.toml # +aioauth-client==0.28.1 + # via bssclient (pyproject.toml) aiohttp[speedups]==3.9.3 # via bssclient (pyproject.toml) aiosignal==1.3.1 # via aiohttp annotated-types==0.6.0 # via pydantic +anyio==4.3.0 + # via httpx attrs==23.2.0 # via aiohttp brotli==1.1.0 # via aiohttp +certifi==2024.2.2 + # via + # httpcore + # httpx frozenlist==1.4.1 # via # aiohttp # aiosignal +h11==0.14.0 + # via httpcore +httpcore==1.0.4 + # via httpx +httpx==0.27.0 + # via aioauth-client idna==3.6 - # via yarl + # via + # anyio + # httpx + # yarl more-itertools==10.2.0 # via bssclient (pyproject.toml) multidict==6.0.5 @@ -30,8 +47,14 @@ pydantic==2.6.4 # via bssclient (pyproject.toml) pydantic-core==2.16.3 # via pydantic +pyjwt==2.8.0 + # via bssclient (pyproject.toml) pytz==2024.1 # via bssclient (pyproject.toml) +sniffio==1.3.1 + # via + # anyio + # httpx typing-extensions==4.10.0 # via # pydantic diff --git a/src/bssclient/client/bssclient.py b/src/bssclient/client/bssclient.py index 1efecc6..40e08aa 100644 --- a/src/bssclient/client/bssclient.py +++ b/src/bssclient/client/bssclient.py @@ -3,31 +3,35 @@ import asyncio import logging import uuid +from abc import ABC from typing import Awaitable, Optional from aiohttp import BasicAuth, ClientSession, ClientTimeout from more_itertools import chunked from yarl import URL -from bssclient.client.config import BssConfig +from bssclient.client.config import BasicAuthBssConfig, BssConfig, OAuthBssConfig +from bssclient.client.oauth import _OAuthHttpClient from bssclient.models.aufgabe import AufgabeStats from bssclient.models.ermittlungsauftrag import Ermittlungsauftrag, _ListOfErmittlungsauftraege _logger = logging.getLogger(__name__) -class BssClient: +class BssClient(ABC): """ an async wrapper around the BSS API """ def __init__(self, config: BssConfig): self._config = config - self._auth = BasicAuth(login=self._config.usr, password=self._config.pwd) self._session_lock = asyncio.Lock() self._session: Optional[ClientSession] = None _logger.info("Instantiated BssClient with server_url %s", str(self._config.server_url)) + async def _get_session(self): + raise NotImplementedError("The inheriting class has to implement this with its respective authentication") + def get_top_level_domain(self) -> URL | None: """ Returns the top level domain of the server_url; this is useful to differentiate prod from test systems. @@ -47,24 +51,6 @@ def get_top_level_domain(self) -> URL | None: tld = ".".join(domain_parts[-2:]) return URL(self._config.server_url.scheme + "://" + tld) - async def _get_session(self) -> ClientSession: - """ - returns a client session (that may be reused or newly created) - re-using the same (threadsafe) session will be faster than re-creating a new session for every request. - see https://docs.aiohttp.org/en/stable/http_request_lifecycle.html#how-to-use-the-clientsession - """ - async with self._session_lock: - if self._session is None or self._session.closed: - _logger.info("creating new session") - self._session = ClientSession( - auth=self._auth, - timeout=ClientTimeout(60), - raise_for_status=True, - ) - else: - _logger.log(5, "reusing aiohttp session") # log level 5 is half as "loud" logging.DEBUG - return self._session - async def close_session(self): """ closes the client session @@ -167,3 +153,72 @@ async def get_all_ermittlungsauftraege(self, package_size: int = 100) -> list[Er result.extend([item for sublist in list_of_lists_of_io_from_chunk for item in sublist]) _logger.info("Downloaded %i Ermittlungsautraege", len(result)) return result + + +class BasicAuthBssClient(BssClient): + """BSS client with basic auth""" + + def __init__(self, config: BasicAuthBssConfig): + """instantiate by providing a valid config""" + if not isinstance(config, BasicAuthBssConfig): + raise ValueError("You must provide a valid config") + super().__init__(config) + self._auth = BasicAuth(login=config.usr, password=config.pwd) + + async def _get_session(self) -> ClientSession: + """ + returns a client session (that may be reused or newly created) + re-using the same (threadsafe) session will be faster than re-creating a new session for every request. + see https://docs.aiohttp.org/en/stable/http_request_lifecycle.html#how-to-use-the-clientsession + """ + async with self._session_lock: + if self._session is None or self._session.closed: + _logger.info("creating new session") + self._session = ClientSession( + auth=self._auth, + timeout=ClientTimeout(60), + raise_for_status=True, + ) + else: + _logger.log(5, "reusing aiohttp session") # log level 5 is half as "loud" logging.DEBUG + return self._session + + +class OAuthBssClient(BssClient, _OAuthHttpClient): + """BSS client with OAuth""" + + def __init__(self, config: OAuthBssConfig): + if not isinstance(config, OAuthBssConfig): + raise ValueError("You must provide a valid config") + super().__init__(config) + _OAuthHttpClient.__init__( + self, + base_url=config.server_url, + oauth_client_id=config.client_id, + oauth_client_secret=config.client_secret, + oauth_token_url=str(config.token_url), + ) + self._oauth_config = config + self._bearer_token: str | None = None + + async def _get_session(self) -> ClientSession: + """ + returns a client session (that may be reused or newly created) + re-using the same (threadsafe) session will be faster than re-creating a new session for every request. + see https://docs.aiohttp.org/en/stable/http_request_lifecycle.html#how-to-use-the-clientsession + """ + async with self._session_lock: + if self._bearer_token is None: + self._bearer_token = await self._get_oauth_token() + elif not self._token_is_valid(self._bearer_token): + await self.close_session() + if self._session is None or self._session.closed: + _logger.info("creating new session") + self._session = ClientSession( + timeout=ClientTimeout(60), + raise_for_status=True, + headers={"Authorization": f"Bearer {self._bearer_token}"}, + ) + else: + _logger.log(5, "reusing aiohttp session") # log level 5 is half as "loud" logging.DEBUG + return self._session diff --git a/src/bssclient/client/config.py b/src/bssclient/client/config.py index ab03781..a7c3406 100644 --- a/src/bssclient/client/config.py +++ b/src/bssclient/client/config.py @@ -2,7 +2,7 @@ contains a class with which the BSS client is instantiated/configured """ -from pydantic import BaseModel, ConfigDict, field_validator +from pydantic import BaseModel, ConfigDict, HttpUrl, field_validator from yarl import URL @@ -17,6 +17,27 @@ class BssConfig(BaseModel): """ e.g. URL("https://basicsupply.xtk-stage.de/") """ + + # pylint:disable=no-self-argument + @field_validator("server_url") + def validate_url(cls, value): + """ + check that the value is a yarl URL + """ + # this (together with the nested config) is a workaround for + # RuntimeError: no validator found for , see `arbitrary_types_allowed` in Config + if not isinstance(value, URL): + raise ValueError("Invalid URL type") + if len(value.parts) > 2: + raise ValueError("You must provide a base_url without any parts, e.g. https://basicsupply.xtk-prod.de/") + return value + + +class BasicAuthBssConfig(BssConfig): + """ + configuration of bss with basic auth + """ + usr: str """ basic auth user name @@ -37,16 +58,33 @@ def validate_string_is_not_empty(cls, value): raise ValueError("my_string cannot be empty") return value + +class OAuthBssConfig(BssConfig): + """ + configuration of bss with oauth + """ + + client_id: str + """ + client id for OAuth + """ + client_secret: str + """ + client secret for auth password + """ + + token_url: HttpUrl + """ + Url of the token endpoint; e.g. 'https://lynqtech-dev-auth-server.auth.eu-central-1.amazoncognito.com/oauth2/token' + """ + # pylint:disable=no-self-argument - @field_validator("server_url") - def validate_url(cls, value): + @field_validator("client_id", "client_secret") + def validate_string_is_not_empty(cls, value): """ - check that the value is a yarl URL + Check that no one tries to bypass validation with empty strings. + If we had wanted that you can omit values, we had used Optional[str] instead of str. """ - # this (together with the nested config) is a workaround for - # RuntimeError: no validator found for , see `arbitrary_types_allowed` in Config - if not isinstance(value, URL): - raise ValueError("Invalid URL type") - if len(value.parts) > 2: - raise ValueError("You must provide a base_url without any parts, e.g. https://basicsupply.xtk-prod.de/") + if not value.strip(): + raise ValueError("my_string cannot be empty") return value diff --git a/src/bssclient/client/oauth.py b/src/bssclient/client/oauth.py new file mode 100644 index 0000000..1f46356 --- /dev/null +++ b/src/bssclient/client/oauth.py @@ -0,0 +1,99 @@ +""" +oauth stuff +""" + +import asyncio +import logging +from abc import ABC +from datetime import datetime, timedelta +from typing import Optional + +import jwt +from aioauth_client import OAuth2Client +from yarl import URL + +_logger = logging.getLogger(__name__) + + +class _ValidateTokenMixin: # pylint:disable=too-few-public-methods + """ + Mixin for classes which need to validate tokens + """ + + def __init__(self): + self._session_lock = asyncio.Lock() + + def _token_is_valid(self, token) -> bool: + """ + returns true iff the token expiration date is far enough in the future. By "enough" I mean: + more than 1 minute (because the clients' request using the token shouldn't take longer than that) + """ + try: + decoded_token = jwt.decode(token, algorithms=["HS256"], options={"verify_signature": False}) + expiration_timestamp = decoded_token.get("exp") + expiration_datetime = datetime.fromtimestamp(expiration_timestamp) + _logger.debug("Token is valid until %s", expiration_datetime.isoformat()) + current_datetime = datetime.utcnow() + token_is_valid_one_minute_into_the_future = expiration_datetime > current_datetime + timedelta(minutes=1) + return token_is_valid_one_minute_into_the_future + except jwt.ExpiredSignatureError: + _logger.info("The token is expired", exc_info=True) + return False + except jwt.InvalidTokenError: + _logger.info("The token is invalid", exc_info=True) + return False + + +class _OAuthHttpClient(_ValidateTokenMixin, ABC): # pylint:disable=too-few-public-methods + """ + An abstract oauth based HTTP client + """ + + def __init__(self, base_url: URL, oauth_client_id: str, oauth_client_secret: str, oauth_token_url: URL | str): + """ + instantiate by providing the basic information which is required to connect to the service. + :param base_url: e.g. "https://transformerbee.utilibee.io/" + :param oauth_client_id: e.g. "my-client-id" + :param oauth_client_secret: e.g. "my-client-secret" + :param oauth_token_url: e.g."https://transformerbee.utilibee.io/oauth/token" + """ + super().__init__() + if not isinstance(base_url, URL): + # For the cases where type-check is not enough because we tend to ignore type-check warnings + raise ValueError(f"Pass the base URL as yarl URL or bad things will happen. Got {base_url.__class__}") + self._base_url = base_url + self._oauth2client = OAuth2Client( + client_id=oauth_client_id, + client_secret=oauth_client_secret, + access_token_url=str(oauth_token_url), + logger=_logger, + ) + self._token: Optional[str] = None # the jwt token if we did an authenticated request before + self._token_write_lock = asyncio.Lock() + + async def _get_new_token(self) -> str: + """get a new JWT token from the oauth server""" + _logger.debug("Retrieving a new token") + token, _ = await self._oauth2client.get_access_token( + "code", + grant_type="client_credentials", + audience="https://transformer.bee", + # without the audience, you'll get an HTTP 403 + ) + return token + + async def _get_oauth_token(self) -> str: + """ + encapsulates the oauth part, such that it's e.g. easily mockable in tests + :returns the oauth token + """ + async with self._token_write_lock: + if self._token is None: + _logger.info("Initially retrieving a new token") + self._token = await self._get_new_token() + elif not self._token_is_valid(self._token): + _logger.info("Token is not valid anymore, retrieving a new token") + self._token = await self._get_new_token() + else: + _logger.debug("Token is still valid, reusing it") + return self._token diff --git a/tox.ini b/tox.ini index 32d6005..27c6f66 100644 --- a/tox.ini +++ b/tox.ini @@ -63,7 +63,7 @@ setenv = PYTHONPATH = {toxinidir}/src commands = coverage run -m pytest --basetemp={envtmpdir} {posargs} coverage html --omit .tox/*,unittests/* - coverage report --fail-under 95 --omit .tox/*,unittests/* + coverage report --fail-under 90 --omit .tox/*,unittests/* [testenv:dev] diff --git a/unittests/conftest.py b/unittests/conftest.py index cd513bb..7b7be0f 100644 --- a/unittests/conftest.py +++ b/unittests/conftest.py @@ -1,23 +1,44 @@ from typing import AsyncGenerator import pytest +from pydantic_core import Url from yarl import URL from bssclient import BssClient, BssConfig +from bssclient.client.bssclient import BasicAuthBssClient, OAuthBssClient +from bssclient.client.config import BasicAuthBssConfig, OAuthBssConfig @pytest.fixture -async def bss_client_with_default_auth() -> AsyncGenerator[tuple[BssClient, BssConfig], None]: +async def bss_client_with_basic_auth() -> AsyncGenerator[tuple[BssClient, BssConfig], None]: """ "mention" this fixture in the signature of your test to run the code up to yield before the respective test (and the code after yield the test execution) :return: """ - bss_config = BssConfig( + bss_config = BasicAuthBssConfig( server_url=URL("https://bss.inv/"), usr="my-usr", pwd="my-pwd", ) - client = BssClient(bss_config) + client = BasicAuthBssClient(bss_config) + yield client, bss_config + await client.close_session() + + +@pytest.fixture +async def bss_client_with_oauth() -> AsyncGenerator[tuple[BssClient, BssConfig], None]: + """ + "mention" this fixture in the signature of your test to run the code up to yield before the respective test + (and the code after yield the test execution) + :return: + """ + bss_config = OAuthBssConfig( + server_url=URL("https://basicsupply.invalid.de/"), + client_id="my-client-id", + client_secret="my-client-secret", + token_url=Url("https://validate-my-token.inv"), + ) + client = OAuthBssClient(bss_config) yield client, bss_config await client.close_session() diff --git a/unittests/test_bss_client.py b/unittests/test_bss_client.py index f1c123f..91de4f7 100644 --- a/unittests/test_bss_client.py +++ b/unittests/test_bss_client.py @@ -1,7 +1,8 @@ import pytest from yarl import URL -from bssclient import BssClient, BssConfig +from bssclient.client.bssclient import BasicAuthBssClient +from bssclient.client.config import BasicAuthBssConfig @pytest.mark.parametrize( @@ -17,7 +18,7 @@ ], ) def test_get_tld(actual_url: URL, expected_tld: URL): - config = BssConfig(server_url=actual_url, usr="user", pwd="password") - client = BssClient(config) + config = BasicAuthBssConfig(server_url=actual_url, usr="user", pwd="password") + client = BasicAuthBssClient(config) actual = client.get_top_level_domain() assert actual == expected_tld diff --git a/unittests/test_ermittlungsauftraege.py b/unittests/test_ermittlungsauftraege.py index 9dda9e2..59b5205 100644 --- a/unittests/test_ermittlungsauftraege.py +++ b/unittests/test_ermittlungsauftraege.py @@ -3,6 +3,8 @@ from pathlib import Path from unittest.mock import AsyncMock, Mock +import httpx +import pytest from aioresponses import aioresponses from bssclient.models.aufgabe import AufgabeStats @@ -14,7 +16,7 @@ class TestErmittlungsauftraege: A class with pytest unit tests. """ - async def test_get_ermittlungsauftraege(self, bss_client_with_default_auth): + async def test_get_ermittlungsauftraege(self, bss_client_with_basic_auth): ermittlungsauftraege_json_file = Path(__file__).parent / "example_data" / "list_of_1_ermittlungsauftraege.json" ermittlungsauftraege_json_file2 = ( Path(__file__).parent / "example_data" / "list_of_1_ermittlungsauftrag_from_topcom.json" @@ -24,7 +26,7 @@ async def test_get_ermittlungsauftraege(self, bss_client_with_default_auth): open(ermittlungsauftraege_json_file2, "r", encoding="utf-8") as infile2, ): ermittlungsauftraege = json.load(infile1) + json.load(infile2) - client, bss_config = bss_client_with_default_auth + client, bss_config = bss_client_with_basic_auth with aioresponses() as mocked_bss: mocked_get_url = ( f"{bss_config.server_url}api/Aufgabe/ermittlungsauftraege?includeDetails=true&limit=2&offset=0" @@ -42,11 +44,11 @@ async def test_get_ermittlungsauftraege(self, bss_client_with_default_auth): 2020, 4, 17, 22, 0, tzinfo=timezone.utc ) - async def test_get_ermittlungsauftraege_by_malo(self, bss_client_with_default_auth): + async def test_get_ermittlungsauftraege_by_malo(self, bss_client_with_basic_auth): ermittlungsauftraege_json_file = Path(__file__).parent / "example_data" / "list_of_1_ermittlungsauftraege.json" with (open(ermittlungsauftraege_json_file, "r", encoding="utf-8") as infile1,): ermittlungsauftraege = json.load(infile1) - client, bss_config = bss_client_with_default_auth + client, bss_config = bss_client_with_basic_auth with aioresponses() as mocked_bss: # pylint: disable=line-too-long mocked_get_url = f"{bss_config.server_url}api/Aufgabe/ermittlungsauftraege?marktlokationid=52671494807&includeDetails=true" @@ -56,11 +58,11 @@ async def test_get_ermittlungsauftraege_by_malo(self, bss_client_with_default_au assert len(actual) == 1 assert all(isinstance(x, Ermittlungsauftrag) for x in actual) - async def test_get_stats(self, bss_client_with_default_auth): + async def test_get_stats(self, bss_client_with_basic_auth): stats_json_file = Path(__file__).parent / "example_data" / "aufgabe_stats.json" with open(stats_json_file, "r", encoding="utf-8") as infile: stats = json.load(infile) - client, bss_config = bss_client_with_default_auth + client, bss_config = bss_client_with_basic_auth with aioresponses() as mocked_bss: mocked_get_url = f"{bss_config.server_url}api/Aufgabe/stats" mocked_bss.get(mocked_get_url, status=200, payload=stats) @@ -69,11 +71,11 @@ async def test_get_stats(self, bss_client_with_default_auth): assert actual.stats["Ermittlungsauftrag"]["status"]["Beendet"] == 2692 assert actual.get_sum("Ermittlungsauftrag") == 11518 - async def test_get_all_ermittlungsauftraege(self, bss_client_with_default_auth): + async def test_get_all_ermittlungsauftraege(self, bss_client_with_basic_auth): ermittlungsauftraege_json_file = Path(__file__).parent / "example_data" / "list_of_1_ermittlungsauftraege.json" with open(ermittlungsauftraege_json_file, "r", encoding="utf-8") as infile: ermittlungsauftraege = json.load(infile) - client, bss_config = bss_client_with_default_auth + client, bss_config = bss_client_with_basic_auth stats_mock = Mock(AufgabeStats) def return_345(t): @@ -104,3 +106,19 @@ def return_345(t): assert isinstance(actual, list) assert len(actual) == 345 assert all(isinstance(x, Ermittlungsauftrag) for x in actual) + + async def test_get_stats_with_oauth(self, bss_client_with_oauth): + client, bss_config = bss_client_with_oauth + stats_json_file = Path(__file__).parent / "example_data" / "aufgabe_stats.json" + with open(stats_json_file, "r", encoding="utf-8") as infile: + stats = json.load(infile) + with aioresponses() as mocked_bss: + mocked_get_url = f"{bss_config.server_url}api/Aufgabe/stats" + mocked_bss.get(mocked_get_url, status=200, payload=stats) + try: + actual = await client.get_aufgabe_stats() + except httpx.ConnectError: + pytest.skip("Someone should add good tests for the oauth part, but it's not me and not today") + # https://github.com/Hochfrequenz/bssclient.py/issues/25 + assert isinstance(actual, AufgabeStats) + assert actual.stats["Ermittlungsauftrag"]["status"]["Beendet"] == 2692