From cb83eec2849febf6edb645f3557b4d26c6dba9c4 Mon Sep 17 00:00:00 2001 From: Alexander Guschin <1aguschin@gmail.com> Date: Thu, 22 Dec 2022 17:26:08 +0600 Subject: [PATCH] Fix bugs (#324) - close https://github.com/iterative/gto/issues/298 - close https://github.com/iterative/gto/issues/285 - close https://github.com/iterative/gto/issues/274 --- gto/api.py | 30 ++++++++++---- gto/base.py | 6 ++- gto/cli.py | 14 ++++++- gto/constants.py | 15 ++++--- gto/registry.py | 8 +++- gto/tag.py | 4 +- .../sample_remote_repo_expected_registry.json | 21 ++++------ tests/test_api.py | 38 +++++++++++++++++ tests/test_cli.py | 3 ++ tests/test_constants.py | 41 +++++++++++++++++++ 10 files changed, 146 insertions(+), 34 deletions(-) create mode 100644 tests/test_constants.py diff --git a/gto/api.py b/gto/api.py index fb1e4a10..36b24ea0 100644 --- a/gto/api.py +++ b/gto/api.py @@ -12,6 +12,7 @@ VERSION, VERSIONS_PER_STAGE, VersionSort, + is_hexsha, mark_artifact_unregistered, parse_shortcut, ) @@ -275,7 +276,7 @@ def deprecate( name=name, message=message, stdout=stdout, - simple=simple if simple is not None else True, + simple=simple, force=force, delete=delete, push=push or is_url_of_remote_repo(repo), @@ -332,6 +333,7 @@ def show( all_commits=False, truncate_hexsha=False, registered_only=False, + deprecated=False, assignments_per_version=ASSIGNMENTS_PER_VERSION, versions_per_stage=VERSIONS_PER_STAGE, sort=VersionSort.Timestamp, @@ -344,6 +346,7 @@ def show( all_branches=all_branches, all_commits=all_commits, registered_only=registered_only, + deprecated=deprecated, assignments_per_version=assignments_per_version, versions_per_stage=versions_per_stage, sort=sort, @@ -356,6 +359,7 @@ def show( all_branches=all_branches, all_commits=all_commits, registered_only=registered_only, + deprecated=deprecated, assignments_per_version=assignments_per_version, versions_per_stage=versions_per_stage, sort=sort, @@ -370,6 +374,7 @@ def _show_registry( all_branches=False, all_commits=False, registered_only=False, + deprecated=False, assignments_per_version: int = None, versions_per_stage: int = None, sort: VersionSort = None, @@ -406,11 +411,13 @@ def format_hexsha(hexsha): for name in stages }, "registered": o.is_registered, + "active": o.is_active, } for o in reg.get_artifacts( all_branches=all_branches, all_commits=all_commits, ).values() + if o.is_active or deprecated } if not table: @@ -427,7 +434,7 @@ def format_hexsha(hexsha): + [d["stage"][name] for name in stages], ), ) - for name, d in models_state.items() + for name, d in sorted(models_state.items()) ], "keys" @@ -438,6 +445,7 @@ def _show_versions( # pylint: disable=too-many-locals all_branches=False, all_commits=False, registered_only=False, + deprecated=False, assignments_per_version: int = None, versions_per_stage: int = None, sort: VersionSort = None, @@ -447,7 +455,7 @@ def _show_versions( # pylint: disable=too-many-locals """List versions of artifact""" def format_hexsha(hexsha): - return hexsha[:7] if truncate_hexsha else hexsha + return hexsha[:7] if truncate_hexsha and is_hexsha(hexsha) else hexsha shortcut = parse_shortcut(name) @@ -468,7 +476,9 @@ def format_hexsha(hexsha): ) versions = [] for v in artifact.get_versions( - include_non_explicit=not registered_only, include_discovered=True + active_only=not deprecated, + include_non_explicit=not registered_only, + include_discovered=True, ): v = v.dict_state() v["stages"] = [ @@ -477,13 +487,14 @@ def format_hexsha(hexsha): for vstage in vstages if vstage.version == v["version"] ] - versions.append(v) + if artifact.is_active or deprecated: + versions.append(v) if shortcut.latest: versions = versions[:1] - if shortcut.version: + elif shortcut.version: versions = [v for v in versions if shortcut.version == v["version"]] - if shortcut.stage: + elif shortcut.stage: versions = [ v for v in versions for a in v["stages"] if shortcut.stage == a["stage"] ] @@ -506,6 +517,7 @@ def format_hexsha(hexsha): ) ) v["commit_hexsha"] = format_hexsha(v["commit_hexsha"]) + v["ref"] = format_hexsha(v["ref"]) if len(v["registrations"]) > 1: raise NotImplementedInGTO( "Multiple registrations are not supported currently. How you got in here?" @@ -523,7 +535,7 @@ def format_hexsha(hexsha): def history( repo: Union[str, Repo], - artifact: str = None, + artifact: Optional[str] = None, # action: str = None, all_branches=False, all_commits=False, @@ -538,7 +550,7 @@ def history( ) def format_hexsha(hexsha): - return hexsha[:7] if truncate_hexsha else hexsha + return hexsha[:7] if truncate_hexsha and is_hexsha(hexsha) else hexsha events = [ OrderedDict( diff --git a/gto/base.py b/gto/base.py index 2913afad..4480f414 100644 --- a/gto/base.py +++ b/gto/base.py @@ -426,7 +426,7 @@ def activated_at(self): @property def is_registered(self): - """Tells if this is an a registered artifact - i.e. there Git tags for it""" + """Tells if this is an a registered artifact - i.e. there are Git tags for it""" return not all( isinstance(e, Commit) for e in self.get_events(direct=True, indirect=True) ) @@ -442,6 +442,7 @@ def __repr__(self) -> str: def get_versions( self, + active_only=True, include_non_explicit=False, include_discovered=False, sort=VersionSort.SemVer, @@ -450,7 +451,8 @@ def get_versions( versions = [ v for v in self.versions - if v.is_active + if not active_only + or v.is_active and ( (v.is_registered and not v.discovered) or (include_discovered and v.discovered) diff --git a/gto/cli.py b/gto/cli.py index 9fe32801..be2d9ee2 100644 --- a/gto/cli.py +++ b/gto/cli.py @@ -212,7 +212,7 @@ def GTOGroupSection(section): arg_version = Argument(..., help="Artifact version") arg_stage = Argument(..., help="Stage to assign") option_version = Option(None, "--version", help="Version to register") -option_stage = Option(None, "--stage", help="Stage to assign") +option_stage = Option(..., "--stage", help="Stage to assign") option_to_version = Option( None, "--to-version", help="Version to use for stage assignment" ) @@ -293,11 +293,18 @@ def callback_sort( # pylint: disable=inconsistent-return-statements option_all = Option(False, "--all", "-a", help="Return all versions sorted") option_registered_only = Option( False, - "--registered-only", "--ro", + "--registered-only", is_flag=True, help="Show only registered versions", ) +option_deprecated = Option( + False, + "-d", + "--deprecated", + is_flag=True, + help="Include deprecated in output", +) option_expected = Option( False, "-e", @@ -794,6 +801,7 @@ def show( # pylint: disable=too-many-locals show_stage: bool = option_show_stage, show_ref: bool = option_show_ref, registered_only: bool = option_registered_only, + deprecated: bool = option_deprecated, assignments_per_version: int = option_assignments_per_version, versions_per_stage: int = option_versions_per_stage, sort: str = option_sort, @@ -810,6 +818,7 @@ def show( # pylint: disable=too-many-locals all_branches=all_branches, all_commits=all_commits, registered_only=registered_only, + deprecated=deprecated, assignments_per_version=assignments_per_version, versions_per_stage=versions_per_stage, sort=sort, @@ -823,6 +832,7 @@ def show( # pylint: disable=too-many-locals all_branches=all_branches, all_commits=all_commits, registered_only=registered_only, + deprecated=deprecated, assignments_per_version=assignments_per_version, versions_per_stage=versions_per_stage, sort=sort, diff --git a/gto/constants.py b/gto/constants.py index 155c7fe4..5dd8256f 100644 --- a/gto/constants.py +++ b/gto/constants.py @@ -32,17 +32,22 @@ class Action(Enum): name = "[a-z][a-z0-9-/]*[a-z0-9]" semver = r"(?P0|[1-9]\d*)\.(?P0|[1-9]\d*)\.(?P0|[1-9]\d*)(?:-(?P(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?(?:\+(?P[0-9a-zA-Z-]+(?:\.[0-9a-zA-Z-]+)*))?" counter = "?P[0-9]+" -name_regexp = re.compile(f"^{name}$") -tag_regexp = re.compile( +name_re = re.compile(f"^{name}$") +tag_re = re.compile( f"^(?P{name})(((#(?P{name})|@(?Pv{semver}))(?P!?))|@((?Pdeprecated)|(?Pcreated)))(#({counter}))?$" ) -shortcut_regexp = re.compile( +shortcut_re = re.compile( f"^(?P{name})(#(?P{name})|@(?Platest|greatest|v{semver}))$" ) +git_hexsha_re = re.compile(r"^[0-9a-fA-F]{40}$") + + +def is_hexsha(value): + return bool(git_hexsha_re.search(value)) def check_name_is_valid(value): - return bool(re.search(name_regexp, value)) + return bool(name_re.search(value)) def assert_name_is_valid(value): @@ -62,7 +67,7 @@ class Shortcut(BaseModel): def parse_shortcut(value): - match = re.search(shortcut_regexp, value) + match = re.search(shortcut_re, value) if match: value = match["artifact"] if match["stage"]: diff --git a/gto/registry.py b/gto/registry.py index 913d15a8..77f97bc7 100644 --- a/gto/registry.py +++ b/gto/registry.py @@ -419,7 +419,13 @@ def deprecate( author: Optional[str] = None, author_email: Optional[str] = None, ) -> Optional[Deprecation]: - if not force: + if force: + if simple: + raise WrongArgs("Can't use 'force' with 'simple=True'") + simple = False + else: + if simple is None: + simple = True found_artifact = self.find_artifact(name) if not found_artifact.is_active: raise WrongArgs("Artifact was deprecated already") diff --git a/gto/tag.py b/gto/tag.py index 1e9dc3a6..d3cf3a3b 100644 --- a/gto/tag.py +++ b/gto/tag.py @@ -24,7 +24,7 @@ TAG, VERSION, Action, - tag_regexp, + tag_re, ) from .exceptions import ( InvalidTagName, @@ -79,7 +79,7 @@ def name_tag( def parse_name(name: str, raise_on_fail: bool = True): - match = re.search(tag_regexp, name) + match = re.search(tag_re, name) if raise_on_fail and not match: raise InvalidTagName(name) if match: diff --git a/tests/resources/sample_remote_repo_expected_registry.json b/tests/resources/sample_remote_repo_expected_registry.json index a6ca962c..6b253ece 100644 --- a/tests/resources/sample_remote_repo_expected_registry.json +++ b/tests/resources/sample_remote_repo_expected_registry.json @@ -1,13 +1,4 @@ { - "deprecated-model": { - "version": "v0.0.1", - "stage": { - "dev": null, - "prod": null, - "staging": null - }, - "registered": true - }, "churn": { "version": "v3.1.0", "stage": { @@ -15,7 +6,8 @@ "prod": "v3.0.0", "staging": "v3.1.0" }, - "registered": true + "registered": true, + "active": true }, "segment": { "version": "v0.4.1", @@ -24,7 +16,8 @@ "prod": null, "staging": null }, - "registered": true + "registered": true, + "active": true }, "cv-class": { "version": "v0.1.13", @@ -33,7 +26,8 @@ "prod": null, "staging": null }, - "registered": true + "registered": true, + "active": true }, "model-a": { "version": "v0.0.2", @@ -42,6 +36,7 @@ "prod": null, "staging": null }, - "registered": true + "registered": true, + "active": true } } diff --git a/tests/test_api.py b/tests/test_api.py index 82eeec51..d2884f15 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -4,6 +4,7 @@ from contextlib import contextmanager from pathlib import Path from tempfile import TemporaryDirectory +from time import sleep from typing import Callable, Optional, Tuple from unittest.mock import call, patch @@ -154,10 +155,15 @@ def test_register_deregister(repo_with_artifact): assert latest.author == author assert latest.author_email == author_email + assert len(gto.api.show(repo.working_dir, name, deprecated=False)) == 2 + gto.api.deregister(repo=repo.working_dir, name=name, version=vname2) latest = gto.api.find_latest_version(repo.working_dir, name) assert latest.version == vname1 + assert len(gto.api.show(repo.working_dir, name, deprecated=False)) == 1 + assert len(gto.api.show(repo.working_dir, name, deprecated=True)) == 2 + def test_assign(repo_with_artifact: Tuple[git.Repo, str]): repo, name = repo_with_artifact @@ -226,6 +232,38 @@ def test_assign_force_is_needed(repo_with_artifact: Tuple[git.Repo, str]): gto.api.assign(repo, name, "staging", ref="HEAD^1", force=True) +def test_unassign(repo_with_artifact): + repo, _ = repo_with_artifact + gto.api.register(repo.working_dir, name="model", ref="HEAD") + gto.api.assign(repo.working_dir, name="model", ref="HEAD", stage="dev") + assert ( + gto.api.find_versions_in_stage(repo.working_dir, name="model", stage="dev") + is not None + ) + + gto.api.unassign(repo.working_dir, name="model", ref="HEAD", stage="dev") + assert ( + gto.api.find_versions_in_stage(repo.working_dir, name="model", stage="dev") + is None + ) + + +def test_deprecate(repo_with_artifact): + repo, _ = repo_with_artifact + gto.api.register(repo.working_dir, name="model", ref="HEAD") + assert len(gto.api.show(repo.working_dir, "model")) == 1 + + sleep(1) + gto.api.deprecate(repo.working_dir, name="model") + assert len(gto.api.show(repo.working_dir, "model", deprecated=False)) == 0 + assert len(gto.api.show(repo.working_dir, "model", deprecated=True)) == 1 + + with pytest.raises(WrongArgs): + gto.api.deprecate(repo.working_dir, name="model") + gto.api.deprecate(repo.working_dir, name="model", simple=True, force=True) + gto.api.deprecate(repo.working_dir, name="model", force=True) + + @contextmanager def environ(**overrides): old = {name: os.environ[name] for name in overrides if name in os.environ} diff --git a/tests/test_cli.py b/tests/test_cli.py index 56b2d884..d6d35baa 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -155,6 +155,9 @@ def test_commands(showcase): _check_successful_cmd( "describe", ["-r", path, "rf@latest"], EXPECTED_DESCRIBE_OUTPUT ) + _check_successful_cmd( + "describe", ["-r", path, "rf@v1.2.3"], EXPECTED_DESCRIBE_OUTPUT + ) _check_successful_cmd( "describe", ["-r", path, "rf", "--path"], "models/random-forest.pkl\n" ) diff --git a/tests/test_constants.py b/tests/test_constants.py new file mode 100644 index 00000000..d7149dcf --- /dev/null +++ b/tests/test_constants.py @@ -0,0 +1,41 @@ +import pytest + +from gto.constants import check_name_is_valid + + +@pytest.mark.parametrize( + "name", + [ + "nn", + "m1", + "model-prod", + "model-prod-v1", + "namespace/model", + ], +) +def test_check_name_is_valid(name): + assert check_name_is_valid(name) + + +@pytest.mark.parametrize( + "name", + [ + "", + "m", + "1", + "m/", + "/m", + "1nn", + "###", + "@@@", + "a model", + "a_model", + "-model", + "model-", + "model@1", + "model#1", + "@namespace/model", + ], +) +def test_check_name_is_invalid(name): + assert not check_name_is_valid(name)