Skip to content

Commit

Permalink
feat: Add OAuth as Option besides Basic Auth (#24)
Browse files Browse the repository at this point in the history
* ➕ Add `pyjwt` and `aioauth_client` as dependency

* add oauth

* fix typos

* f*** the coverage

* fix type_check

* type check for real

* fo real2
  • Loading branch information
hf-kklein authored Mar 19, 2024
1 parent d48a158 commit 1d2a3c4
Show file tree
Hide file tree
Showing 9 changed files with 305 additions and 48 deletions.
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ dependencies = [
"pydantic>=2.0.0",
"aiohttp[speedups]>=3.9.3",
"more-itertools",
"pytz"
"pytz",
"pyjwt",
"aioauth_client",
] # add all the dependencies here
dynamic = ["readme", "version"]

Expand Down
25 changes: 24 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,39 @@
#
# pip-compile pyproject.toml
#
aioauth-client==0.28.1
# via bssclient (pyproject.toml)
aiohttp[speedups]==3.9.3
# via bssclient (pyproject.toml)
aiosignal==1.3.1
# via aiohttp
annotated-types==0.6.0
# via pydantic
anyio==4.3.0
# via httpx
attrs==23.2.0
# via aiohttp
brotli==1.1.0
# via aiohttp
certifi==2024.2.2
# via
# httpcore
# httpx
frozenlist==1.4.1
# via
# aiohttp
# aiosignal
h11==0.14.0
# via httpcore
httpcore==1.0.4
# via httpx
httpx==0.27.0
# via aioauth-client
idna==3.6
# via yarl
# via
# anyio
# httpx
# yarl
more-itertools==10.2.0
# via bssclient (pyproject.toml)
multidict==6.0.5
Expand All @@ -30,8 +47,14 @@ pydantic==2.6.4
# via bssclient (pyproject.toml)
pydantic-core==2.16.3
# via pydantic
pyjwt==2.8.0
# via bssclient (pyproject.toml)
pytz==2024.1
# via bssclient (pyproject.toml)
sniffio==1.3.1
# via
# anyio
# httpx
typing-extensions==4.10.0
# via
# pydantic
Expand Down
97 changes: 76 additions & 21 deletions src/bssclient/client/bssclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,35 @@
import asyncio
import logging
import uuid
from abc import ABC
from typing import Awaitable, Optional

from aiohttp import BasicAuth, ClientSession, ClientTimeout
from more_itertools import chunked
from yarl import URL

from bssclient.client.config import BssConfig
from bssclient.client.config import BasicAuthBssConfig, BssConfig, OAuthBssConfig
from bssclient.client.oauth import _OAuthHttpClient
from bssclient.models.aufgabe import AufgabeStats
from bssclient.models.ermittlungsauftrag import Ermittlungsauftrag, _ListOfErmittlungsauftraege

_logger = logging.getLogger(__name__)


class BssClient:
class BssClient(ABC):
"""
an async wrapper around the BSS API
"""

def __init__(self, config: BssConfig):
self._config = config
self._auth = BasicAuth(login=self._config.usr, password=self._config.pwd)
self._session_lock = asyncio.Lock()
self._session: Optional[ClientSession] = None
_logger.info("Instantiated BssClient with server_url %s", str(self._config.server_url))

async def _get_session(self):
raise NotImplementedError("The inheriting class has to implement this with its respective authentication")

def get_top_level_domain(self) -> URL | None:
"""
Returns the top level domain of the server_url; this is useful to differentiate prod from test systems.
Expand All @@ -47,24 +51,6 @@ def get_top_level_domain(self) -> URL | None:
tld = ".".join(domain_parts[-2:])
return URL(self._config.server_url.scheme + "://" + tld)

async def _get_session(self) -> ClientSession:
"""
returns a client session (that may be reused or newly created)
re-using the same (threadsafe) session will be faster than re-creating a new session for every request.
see https://docs.aiohttp.org/en/stable/http_request_lifecycle.html#how-to-use-the-clientsession
"""
async with self._session_lock:
if self._session is None or self._session.closed:
_logger.info("creating new session")
self._session = ClientSession(
auth=self._auth,
timeout=ClientTimeout(60),
raise_for_status=True,
)
else:
_logger.log(5, "reusing aiohttp session") # log level 5 is half as "loud" logging.DEBUG
return self._session

async def close_session(self):
"""
closes the client session
Expand Down Expand Up @@ -167,3 +153,72 @@ async def get_all_ermittlungsauftraege(self, package_size: int = 100) -> list[Er
result.extend([item for sublist in list_of_lists_of_io_from_chunk for item in sublist])
_logger.info("Downloaded %i Ermittlungsautraege", len(result))
return result


class BasicAuthBssClient(BssClient):
"""BSS client with basic auth"""

def __init__(self, config: BasicAuthBssConfig):
"""instantiate by providing a valid config"""
if not isinstance(config, BasicAuthBssConfig):
raise ValueError("You must provide a valid config")
super().__init__(config)
self._auth = BasicAuth(login=config.usr, password=config.pwd)

async def _get_session(self) -> ClientSession:
"""
returns a client session (that may be reused or newly created)
re-using the same (threadsafe) session will be faster than re-creating a new session for every request.
see https://docs.aiohttp.org/en/stable/http_request_lifecycle.html#how-to-use-the-clientsession
"""
async with self._session_lock:
if self._session is None or self._session.closed:
_logger.info("creating new session")
self._session = ClientSession(
auth=self._auth,
timeout=ClientTimeout(60),
raise_for_status=True,
)
else:
_logger.log(5, "reusing aiohttp session") # log level 5 is half as "loud" logging.DEBUG
return self._session


class OAuthBssClient(BssClient, _OAuthHttpClient):
"""BSS client with OAuth"""

def __init__(self, config: OAuthBssConfig):
if not isinstance(config, OAuthBssConfig):
raise ValueError("You must provide a valid config")
super().__init__(config)
_OAuthHttpClient.__init__(
self,
base_url=config.server_url,
oauth_client_id=config.client_id,
oauth_client_secret=config.client_secret,
oauth_token_url=str(config.token_url),
)
self._oauth_config = config
self._bearer_token: str | None = None

async def _get_session(self) -> ClientSession:
"""
returns a client session (that may be reused or newly created)
re-using the same (threadsafe) session will be faster than re-creating a new session for every request.
see https://docs.aiohttp.org/en/stable/http_request_lifecycle.html#how-to-use-the-clientsession
"""
async with self._session_lock:
if self._bearer_token is None:
self._bearer_token = await self._get_oauth_token()
elif not self._token_is_valid(self._bearer_token):
await self.close_session()
if self._session is None or self._session.closed:
_logger.info("creating new session")
self._session = ClientSession(
timeout=ClientTimeout(60),
raise_for_status=True,
headers={"Authorization": f"Bearer {self._bearer_token}"},
)
else:
_logger.log(5, "reusing aiohttp session") # log level 5 is half as "loud" logging.DEBUG
return self._session
58 changes: 48 additions & 10 deletions src/bssclient/client/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
contains a class with which the BSS client is instantiated/configured
"""

from pydantic import BaseModel, ConfigDict, field_validator
from pydantic import BaseModel, ConfigDict, HttpUrl, field_validator
from yarl import URL


Expand All @@ -17,6 +17,27 @@ class BssConfig(BaseModel):
"""
e.g. URL("https://basicsupply.xtk-stage.de/")
"""

# pylint:disable=no-self-argument
@field_validator("server_url")
def validate_url(cls, value):
"""
check that the value is a yarl URL
"""
# this (together with the nested config) is a workaround for
# RuntimeError: no validator found for <class 'yarl.URL'>, see `arbitrary_types_allowed` in Config
if not isinstance(value, URL):
raise ValueError("Invalid URL type")
if len(value.parts) > 2:
raise ValueError("You must provide a base_url without any parts, e.g. https://basicsupply.xtk-prod.de/")
return value


class BasicAuthBssConfig(BssConfig):
"""
configuration of bss with basic auth
"""

usr: str
"""
basic auth user name
Expand All @@ -37,16 +58,33 @@ def validate_string_is_not_empty(cls, value):
raise ValueError("my_string cannot be empty")
return value


class OAuthBssConfig(BssConfig):
"""
configuration of bss with oauth
"""

client_id: str
"""
client id for OAuth
"""
client_secret: str
"""
client secret for auth password
"""

token_url: HttpUrl
"""
Url of the token endpoint; e.g. 'https://lynqtech-dev-auth-server.auth.eu-central-1.amazoncognito.com/oauth2/token'
"""

# pylint:disable=no-self-argument
@field_validator("server_url")
def validate_url(cls, value):
@field_validator("client_id", "client_secret")
def validate_string_is_not_empty(cls, value):
"""
check that the value is a yarl URL
Check that no one tries to bypass validation with empty strings.
If we had wanted that you can omit values, we had used Optional[str] instead of str.
"""
# this (together with the nested config) is a workaround for
# RuntimeError: no validator found for <class 'yarl.URL'>, see `arbitrary_types_allowed` in Config
if not isinstance(value, URL):
raise ValueError("Invalid URL type")
if len(value.parts) > 2:
raise ValueError("You must provide a base_url without any parts, e.g. https://basicsupply.xtk-prod.de/")
if not value.strip():
raise ValueError("my_string cannot be empty")
return value
99 changes: 99 additions & 0 deletions src/bssclient/client/oauth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
"""
oauth stuff
"""

import asyncio
import logging
from abc import ABC
from datetime import datetime, timedelta
from typing import Optional

import jwt
from aioauth_client import OAuth2Client
from yarl import URL

_logger = logging.getLogger(__name__)


class _ValidateTokenMixin: # pylint:disable=too-few-public-methods
"""
Mixin for classes which need to validate tokens
"""

def __init__(self):
self._session_lock = asyncio.Lock()

def _token_is_valid(self, token) -> bool:
"""
returns true iff the token expiration date is far enough in the future. By "enough" I mean:
more than 1 minute (because the clients' request using the token shouldn't take longer than that)
"""
try:
decoded_token = jwt.decode(token, algorithms=["HS256"], options={"verify_signature": False})
expiration_timestamp = decoded_token.get("exp")
expiration_datetime = datetime.fromtimestamp(expiration_timestamp)
_logger.debug("Token is valid until %s", expiration_datetime.isoformat())
current_datetime = datetime.utcnow()
token_is_valid_one_minute_into_the_future = expiration_datetime > current_datetime + timedelta(minutes=1)
return token_is_valid_one_minute_into_the_future
except jwt.ExpiredSignatureError:
_logger.info("The token is expired", exc_info=True)
return False
except jwt.InvalidTokenError:
_logger.info("The token is invalid", exc_info=True)
return False


class _OAuthHttpClient(_ValidateTokenMixin, ABC): # pylint:disable=too-few-public-methods
"""
An abstract oauth based HTTP client
"""

def __init__(self, base_url: URL, oauth_client_id: str, oauth_client_secret: str, oauth_token_url: URL | str):
"""
instantiate by providing the basic information which is required to connect to the service.
:param base_url: e.g. "https://transformerbee.utilibee.io/"
:param oauth_client_id: e.g. "my-client-id"
:param oauth_client_secret: e.g. "my-client-secret"
:param oauth_token_url: e.g."https://transformerbee.utilibee.io/oauth/token"
"""
super().__init__()
if not isinstance(base_url, URL):
# For the cases where type-check is not enough because we tend to ignore type-check warnings
raise ValueError(f"Pass the base URL as yarl URL or bad things will happen. Got {base_url.__class__}")
self._base_url = base_url
self._oauth2client = OAuth2Client(
client_id=oauth_client_id,
client_secret=oauth_client_secret,
access_token_url=str(oauth_token_url),
logger=_logger,
)
self._token: Optional[str] = None # the jwt token if we did an authenticated request before
self._token_write_lock = asyncio.Lock()

async def _get_new_token(self) -> str:
"""get a new JWT token from the oauth server"""
_logger.debug("Retrieving a new token")
token, _ = await self._oauth2client.get_access_token(
"code",
grant_type="client_credentials",
audience="https://transformer.bee",
# without the audience, you'll get an HTTP 403
)
return token

async def _get_oauth_token(self) -> str:
"""
encapsulates the oauth part, such that it's e.g. easily mockable in tests
:returns the oauth token
"""
async with self._token_write_lock:
if self._token is None:
_logger.info("Initially retrieving a new token")
self._token = await self._get_new_token()
elif not self._token_is_valid(self._token):
_logger.info("Token is not valid anymore, retrieving a new token")
self._token = await self._get_new_token()
else:
_logger.debug("Token is still valid, reusing it")
return self._token
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ setenv = PYTHONPATH = {toxinidir}/src
commands =
coverage run -m pytest --basetemp={envtmpdir} {posargs}
coverage html --omit .tox/*,unittests/*
coverage report --fail-under 95 --omit .tox/*,unittests/*
coverage report --fail-under 90 --omit .tox/*,unittests/*


[testenv:dev]
Expand Down
Loading

0 comments on commit 1d2a3c4

Please sign in to comment.