Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Customizable token cache #759

Open
wants to merge 1 commit into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions msal/application.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import annotations
import functools
import json
import time
Expand Down Expand Up @@ -238,6 +239,10 @@ class ClientApplication(object):
"You can enable broker by following these instructions. "
"https://msal-python.readthedocs.io/en/latest/#publicclientapplication")

_TOKEN_CACHE_DATA: dict[str, str] = { # field_in_data: field_in_cache
"key_id": "key_id", # Some token types (SSH-certs, POP) are bound to a key
}

def __init__(
self, client_id,
client_credential=None, authority=None, validate_authority=True,
Expand Down Expand Up @@ -651,6 +656,7 @@ def __init__(

self._decide_broker(allow_broker, enable_pii_log)
self.token_cache = token_cache or TokenCache()
self.token_cache._set(data_to_at=self._TOKEN_CACHE_DATA)
self._region_configured = azure_region
self._region_detected = None
self.client, self._regional_client = self._build_client(
Expand Down Expand Up @@ -1528,9 +1534,10 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
"realm": authority.tenant,
"home_account_id": (account or {}).get("home_account_id"),
}
key_id = kwargs.get("data", {}).get("key_id")
if key_id: # Some token types (SSH-certs, POP) are bound to a key
query["key_id"] = key_id
for field_in_data, field_in_cache in self._TOKEN_CACHE_DATA.items():
value = kwargs.get("data", {}).get(field_in_data)
if value:
query[field_in_cache] = value
now = time.time()
refresh_reason = msal.telemetry.AT_ABSENT
for entry in self.token_cache.search( # A generator allows us to
Expand Down
36 changes: 30 additions & 6 deletions msal/token_cache.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import json
from __future__ import annotations
import json
import threading
import time
from typing import Optional # Needed in Python 3.7 & 3.8
import logging
import warnings

Expand Down Expand Up @@ -39,6 +41,25 @@ class AuthorityType:
ADFS = "ADFS"
MSSTS = "MSSTS" # MSSTS means AAD v2 for both AAD & MSA

_data_to_at: dict[str, str] = { # field_in_data: field_in_cache
# Store extra data which we explicitly allow,
# so that we won't accidentally store a user's password etc.
# It can be used to store for example key_id used in SSH-cert or POP
}
_response_to_at: dict[str, str] = { # field_in_response: field_in_cache
}

def _set(
self,
*,
data_to_at: Optional[dict[str, str]] = None,
response_to_at: Optional[dict[str, str]] = None,
) -> None:
# This helper should probably be better in __init__(),
# but there is no easy way for MSAL EX to pick up a kwargs
self._data_to_at = data_to_at or {}
self._response_to_at = response_to_at or {}

def __init__(self):
self._lock = threading.RLock()
self._cache = {}
Expand Down Expand Up @@ -267,11 +288,14 @@ def __add(self, event, now=None):
"expires_on": str(now + expires_in), # Same here
"extended_expires_on": str(now + ext_expires_in) # Same here
}
at.update({k: data[k] for k in data if k in {
# Also store extra data which we explicitly allow
# So that we won't accidentally store a user's password etc.
"key_id", # It happens in SSH-cert or POP scenario
}})
for field_in_resp, field_in_cache in self._response_to_at.items():
value = response.get(field_in_resp)
if value:
at[field_in_cache] = value
for field_in_data, field_in_cache in self._data_to_at.items():
value = data.get(field_in_data)
if value:
at[field_in_cache] = value
if "refresh_in" in response:
refresh_in = response["refresh_in"] # It is an integer
at["refresh_on"] = str(now + refresh_in) # Schema wants a string
Expand Down
1 change: 1 addition & 0 deletions tests/test_token_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ def assertFoundAccessToken(self, *, scopes, query, data=None, now=None):
def _test_data_should_be_saved_and_searchable_in_access_token(self, data):
scopes = ["s2", "s1", "s3"] # Not in particular order
now = 1000
self.cache._set(data_to_at={"key_id": "key_id"})
self.cache.add({
"data": data,
"client_id": "my_client_id",
Expand Down