diff --git a/pypitoken/restrictions.py b/pypitoken/restrictions.py index c9849ad..8c97a72 100644 --- a/pypitoken/restrictions.py +++ b/pypitoken/restrictions.py @@ -371,11 +371,16 @@ def check(self, context: Context) -> None: @classmethod def from_parameters( cls, - project_names: list[str] | None = None, + project_names: Iterable[str] | None = None, **kwargs, ) -> ProjectNamesRestriction | None: if project_names is not None: - return cls(project_names=project_names) + if isinstance(project_names, str): + raise exceptions.InvalidRestriction( + "project_names should be an iterable of strings. " + "Received a single string not wrapped in an iterable." + ) + return cls(project_names=list(project_names)) return None @@ -435,11 +440,16 @@ def check(self, context: Context) -> None: @classmethod def from_parameters( cls, - project_ids: list[str] | None = None, + project_ids: Iterable[str] | None = None, **kwargs, ) -> ProjectIDsRestriction | None: if project_ids is not None: - return cls(project_ids=project_ids) + if isinstance(project_ids, str): + raise exceptions.InvalidRestriction( + "project_ids should be an iterable of strings. " + "Received a single string not wrapped in an iterable." + ) + return cls(project_ids=list(project_ids)) return None @@ -598,11 +608,16 @@ def check(self, context: Context) -> None: @classmethod def from_parameters( cls, - legacy_project_names: list[str] | None = None, + legacy_project_names: Iterable[str] | None = None, **kwargs, ) -> LegacyProjectNamesRestriction | None: if legacy_project_names is not None: - return cls(project_names=legacy_project_names) + if isinstance(legacy_project_names, str): + raise exceptions.InvalidRestriction( + "legacy_project_names should be an iterable of strings. " + "Received a single string not wrapped in an iterable." + ) + return cls(project_names=list(legacy_project_names)) return None diff --git a/pypitoken/token.py b/pypitoken/token.py index 4c8457b..869af59 100644 --- a/pypitoken/token.py +++ b/pypitoken/token.py @@ -2,6 +2,7 @@ import datetime import functools +from typing import Iterable import pymacaroons from typing_extensions import ParamSpec @@ -152,11 +153,11 @@ def restrict( self, not_before: int | datetime.datetime | None = None, not_after: int | datetime.datetime | None = None, - project_names: list[str] | None = None, - project_ids: list[str] | None = None, + project_names: Iterable[str] | None = None, + project_ids: Iterable[str] | None = None, user_id: str | None = None, # Legacy params - legacy_project_names: list[str] | None = None, + legacy_project_names: Iterable[str] | None = None, legacy_not_before: int | datetime.datetime | None = None, legacy_not_after: int | datetime.datetime | None = None, legacy_noop: bool | None = None, diff --git a/tests/test_restrictions.py b/tests/test_restrictions.py index bf861d0..6ba85c5 100644 --- a/tests/test_restrictions.py +++ b/tests/test_restrictions.py @@ -198,6 +198,13 @@ def test__LegacyProjectNamesRestriction__from_parameters__not_empty(): ) == restrictions.LegacyProjectNamesRestriction(project_names=["a", "b"]) +def test__LegacyProjectNamesRestriction__from_parameters__bare_string(): + with pytest.raises(exceptions.InvalidRestriction): + restrictions.LegacyProjectNamesRestriction.from_parameters( + legacy_project_names="a" + ) + + def test__LegacyDateRestriction__load_value__pass(): assert restrictions.LegacyDateRestriction._load_value( value={"nbf": 1_234_567_890, "exp": 1_234_567_900} @@ -546,6 +553,11 @@ def test__ProjectNamesRestriction__from_parameters__empty(): assert restrictions.ProjectNamesRestriction.from_parameters() is None +def test__ProjectNamesRestriction__from_parameters__bare_string(): + with pytest.raises(exceptions.InvalidRestriction): + restrictions.ProjectNamesRestriction.from_parameters(project_names="a") + + def test__ProjectNamesRestriction__from_parameters__not_empty(): assert restrictions.ProjectNamesRestriction.from_parameters( project_names=["a", "b"] @@ -669,6 +681,13 @@ def test__ProjectIDsRestriction__from_parameters__empty(): assert restrictions.ProjectIDsRestriction.from_parameters() is None +def test__ProjectIDsRestriction__from_parameters__bare_string(): + with pytest.raises(exceptions.InvalidRestriction): + restrictions.ProjectIDsRestriction.from_parameters( + project_ids="00000000-0000-0000-0000-000000000000" + ) + + def test__ProjectIDsRestriction__from_parameters__not_empty(): assert restrictions.ProjectIDsRestriction.from_parameters( project_ids=[