diff --git a/rubicon_ml/client/artifact.py b/rubicon_ml/client/artifact.py index d132a9cb..3bf114a5 100644 --- a/rubicon_ml/client/artifact.py +++ b/rubicon_ml/client/artifact.py @@ -2,10 +2,11 @@ import fsspec -from rubicon_ml.client import Base +from rubicon_ml.client.base import Base +from rubicon_ml.client.mixin import TagMixin -class Artifact(Base): +class Artifact(Base, TagMixin): """A client artifact. An `artifact` is a catch-all for any other type of @@ -35,7 +36,7 @@ def __init__(self, domain, parent): def _get_data(self): """Loads the data associated with this artifact.""" - project_name, experiment_id = self.parent._get_parent_identifiers() + project_name, experiment_id = self.parent._get_identifiers() self._data = self.repository.get_artifact_data( project_name, self.id, experiment_id=experiment_id diff --git a/rubicon_ml/client/asynchronous/artifact.py b/rubicon_ml/client/asynchronous/artifact.py index 1f3faaae..a2bcadba 100644 --- a/rubicon_ml/client/asynchronous/artifact.py +++ b/rubicon_ml/client/asynchronous/artifact.py @@ -29,7 +29,7 @@ async def _get_data(self): """Overrides `rubicon.client.Artifact._get_data` to asynchronously load the data associated with this artifact. 
""" - project_name, experiment_id = self.parent._get_parent_identifiers() + project_name, experiment_id = self.parent._get_identifiers() self._data = await self.repository.get_artifact_data( project_name, self.id, experiment_id=experiment_id diff --git a/rubicon_ml/client/asynchronous/mixin.py b/rubicon_ml/client/asynchronous/mixin.py index b64e3319..5c604861 100644 --- a/rubicon_ml/client/asynchronous/mixin.py +++ b/rubicon_ml/client/asynchronous/mixin.py @@ -3,11 +3,10 @@ from rubicon_ml import domain from rubicon_ml.client import asynchronous as client -from rubicon_ml.client.mixin import MultiParentMixin from rubicon_ml.client.utils.tags import has_tag_requirements -class ArtifactMixin(MultiParentMixin): +class ArtifactMixin: """Adds artifact support to an asynchronous client object.""" async def log_artifact( @@ -66,7 +65,7 @@ async def log_artifact( artifact = domain.Artifact(name=name, description=description, parent_id=self._domain.id) - project_name, experiment_id = self._get_parent_identifiers() + project_name, experiment_id = self._get_identifiers() await self.repository.create_artifact( artifact, data_bytes, project_name, experiment_id=experiment_id ) @@ -130,7 +129,7 @@ async def artifacts(self): list of rubicon.client.Artifact The artifacts previously logged to this client object. """ - project_name, experiment_id = self._get_parent_identifiers() + project_name, experiment_id = self._get_identifiers() self._artifacts = [ client.Artifact(a, self) @@ -151,7 +150,7 @@ async def delete_artifacts(self, ids): ids : list of str The ids of the artifacts to delete. 
""" - project_name, experiment_id = self._get_parent_identifiers() + project_name, experiment_id = self._get_identifiers() await asyncio.gather( *[ @@ -163,7 +162,7 @@ async def delete_artifacts(self, ids): ) -class DataframeMixin(MultiParentMixin): +class DataframeMixin: """Adds dataframe support to an asynchronous client object.""" async def log_dataframe(self, df, description=None, tags=[]): @@ -187,7 +186,7 @@ async def log_dataframe(self, df, description=None, tags=[]): """ dataframe = domain.Dataframe(parent_id=self._domain.id, description=description, tags=tags) - project_name, experiment_id = self._get_parent_identifiers() + project_name, experiment_id = self._get_identifiers() await self.repository.create_dataframe( dataframe, df, project_name, experiment_id=experiment_id ) @@ -226,7 +225,7 @@ async def dataframes(self, tags=[], qtype="or"): list of rubicon.client.Dataframe The dataframes previously logged to this client object. """ - project_name, experiment_id = self._get_parent_identifiers() + project_name, experiment_id = self._get_identifiers() dataframes = [ client.Dataframe(d, self) for d in await self.repository.get_dataframes_metadata( @@ -248,7 +247,7 @@ async def delete_dataframes(self, ids): ids : list of str The ids of the dataframes to delete. """ - project_name, experiment_id = self._get_parent_identifiers() + project_name, experiment_id = self._get_identifiers() await asyncio.gather( *[ diff --git a/rubicon_ml/client/dataframe.py b/rubicon_ml/client/dataframe.py index af158ceb..2b5e6a80 100644 --- a/rubicon_ml/client/dataframe.py +++ b/rubicon_ml/client/dataframe.py @@ -41,7 +41,7 @@ def get_data(self, df_type="pandas"): The type of dataframe. Can be either `pandas` or `dask`. Defaults to 'pandas'. 
""" - project_name, experiment_id = self.parent._get_parent_identifiers() + project_name, experiment_id = self.parent._get_identifiers() self._data = self.repository.get_dataframe_data( project_name, diff --git a/rubicon_ml/client/experiment.py b/rubicon_ml/client/experiment.py index 4561056b..f72a8ec1 100644 --- a/rubicon_ml/client/experiment.py +++ b/rubicon_ml/client/experiment.py @@ -37,6 +37,10 @@ def __init__(self, domain, parent): self._features = [] self._parameters = [] + def _get_identifiers(self): + """Get the experiment's project's name and the experiment's ID.""" + return self.project.name, self.id + def log_metric(self, name, value, directionality="score", description=None): """Create a metric under the experiment. diff --git a/rubicon_ml/client/mixin.py b/rubicon_ml/client/mixin.py index b64f19bb..1eefecc1 100644 --- a/rubicon_ml/client/mixin.py +++ b/rubicon_ml/client/mixin.py @@ -10,27 +10,7 @@ from rubicon_ml.exceptions import RubiconException -class MultiParentMixin: - """Adds utils for client objects that can be logged - to either a `Project` or `Experiment`. - """ - - def _get_parent_identifiers(self): - """Get the project name and experiment ID (or - `None`) of this client object's parent(s). - """ - experiment_id = None - - if isinstance(self, client.Project): - project_name = self.name - else: - project_name = self.project.name - experiment_id = self.id - - return project_name, experiment_id - - -class ArtifactMixin(MultiParentMixin): +class ArtifactMixin: """Adds artifact support to a client object.""" def _validate_data(self, data_bytes, data_file, data_path, name): @@ -60,7 +40,13 @@ def _validate_data(self, data_bytes, data_file, data_path, name): return data_bytes, name def log_artifact( - self, data_bytes=None, data_file=None, data_path=None, name=None, description=None + self, + data_bytes=None, + data_file=None, + data_path=None, + name=None, + description=None, + tags=[], ): """Log an artifact to this client object. 
@@ -80,6 +66,9 @@ def log_artifact( description : str, optional A description of the artifact. Use to provide additional context. + tags : list of str, optional + Values to tag the artifact with. Use tags to organize and + filter your artifacts. Notes ----- @@ -112,9 +101,14 @@ def log_artifact( """ data_bytes, name = self._validate_data(data_bytes, data_file, data_path, name) - artifact = domain.Artifact(name=name, description=description, parent_id=self._domain.id) + artifact = domain.Artifact( + name=name, + description=description, + parent_id=self._domain.id, + tags=tags, + ) - project_name, experiment_id = self._get_parent_identifiers() + project_name, experiment_id = self._get_identifiers() self.repository.create_artifact( artifact, data_bytes, project_name, experiment_id=experiment_id ) @@ -202,7 +196,7 @@ def artifacts(self, name=None): list of rubicon.client.Artifact The artifacts previously logged to this client object. """ - project_name, experiment_id = self._get_parent_identifiers() + project_name, experiment_id = self._get_identifiers() self._artifacts = [ client.Artifact(a, self) for a in self.repository.get_artifacts_metadata( @@ -245,7 +239,7 @@ def artifact(self, name=None, id=None): artifact = artifacts[-1] else: - project_name, experiment_id = self._get_parent_identifiers() + project_name, experiment_id = self._get_identifiers() artifact = client.Artifact( self.repository.get_artifact_metadata(project_name, id, experiment_id), self ) @@ -261,13 +255,13 @@ def delete_artifacts(self, ids): ids : list of str The ids of the artifacts to delete. 
""" - project_name, experiment_id = self._get_parent_identifiers() + project_name, experiment_id = self._get_identifiers() for artifact_id in ids: self.repository.delete_artifact(project_name, artifact_id, experiment_id=experiment_id) -class DataframeMixin(MultiParentMixin): +class DataframeMixin: """Adds dataframe support to a client object.""" def log_dataframe(self, df, description=None, name=None, tags=[]): @@ -295,7 +289,7 @@ def log_dataframe(self, df, description=None, name=None, tags=[]): tags=tags, ) - project_name, experiment_id = self._get_parent_identifiers() + project_name, experiment_id = self._get_identifiers() self.repository.create_dataframe(dataframe, df, project_name, experiment_id=experiment_id) return client.Dataframe(dataframe, self) @@ -336,7 +330,7 @@ def dataframes(self, name=None, tags=[], qtype="or"): list of rubicon.client.Dataframe The dataframes previously logged to this client object. """ - project_name, experiment_id = self._get_parent_identifiers() + project_name, experiment_id = self._get_identifiers() dataframes = [ client.Dataframe(d, self) for d in self.repository.get_dataframes_metadata( @@ -379,7 +373,7 @@ def dataframe(self, name=None, id=None): dataframe = dataframes[-1] else: - project_name, experiment_id = self._get_parent_identifiers() + project_name, experiment_id = self._get_identifiers() dataframe = client.Dataframe( self.repository.get_dataframe_metadata( project_name, experiment_id=experiment_id, dataframe_id=id @@ -398,7 +392,7 @@ def delete_dataframes(self, ids): ids : list of str The ids of the dataframes to delete. 
""" - project_name, experiment_id = self._get_parent_identifiers() + project_name, experiment_id = self._get_identifiers() for dataframe_id in ids: self.repository.delete_dataframe( @@ -410,15 +404,16 @@ class TagMixin: """Adds tag support to a client object.""" def _get_taggable_identifiers(self): - dataframe_id = None + project_name, experiment_id = self._parent._get_identifiers() + entity_id = None - if isinstance(self, client.Dataframe): - project_name, experiment_id = self.parent._get_parent_identifiers() - dataframe_id = self.id + # experiments are not required to return an entity ID - they are the entity + if isinstance(self, client.Experiment): + experiment_id = self.id else: - project_name, experiment_id = self._get_parent_identifiers() + entity_id = self.id - return project_name, experiment_id, dataframe_id + return project_name, experiment_id, entity_id def add_tags(self, tags): """Add tags to this client object. @@ -428,11 +423,15 @@ def add_tags(self, tags): tags : list of str The tag values to add. """ - project_name, experiment_id, dataframe_id = self._get_taggable_identifiers() + project_name, experiment_id, entity_id = self._get_taggable_identifiers() self._domain.add_tags(tags) self.repository.add_tags( - project_name, tags, experiment_id=experiment_id, dataframe_id=dataframe_id + project_name, + tags, + experiment_id=experiment_id, + entity_id=entity_id, + entity_type=self.__class__.__name__, ) def remove_tags(self, tags): @@ -443,11 +442,15 @@ def remove_tags(self, tags): tags : list of str The tag values to remove. 
""" - project_name, experiment_id, dataframe_id = self._get_taggable_identifiers() + project_name, experiment_id, entity_id = self._get_taggable_identifiers() self._domain.remove_tags(tags) self.repository.remove_tags( - project_name, tags, experiment_id=experiment_id, dataframe_id=dataframe_id + project_name, + tags, + experiment_id=experiment_id, + entity_id=entity_id, + entity_type=self.__class__.__name__, ) def _update_tags(self, tag_data): @@ -461,11 +464,12 @@ def _update_tags(self, tag_data): @property def tags(self): """Get this client object's tags.""" - project_name, experiment_id, dataframe_id = self._get_taggable_identifiers() + project_name, experiment_id, entity_id = self._get_taggable_identifiers() tag_data = self.repository.get_tags( project_name, experiment_id=experiment_id, - dataframe_id=dataframe_id, + entity_id=entity_id, + entity_type=self.__class__.__name__, ) self._update_tags(tag_data) diff --git a/rubicon_ml/client/project.py b/rubicon_ml/client/project.py index f2c9f3c3..bbfbead6 100644 --- a/rubicon_ml/client/project.py +++ b/rubicon_ml/client/project.py @@ -49,6 +49,10 @@ def _get_commit_hash(self): return completed_process.stdout.decode("utf8").replace("\n", "") + def _get_identifiers(self): + """Get the project's name.""" + return self.name, None + def _create_experiment_domain( self, name, diff --git a/rubicon_ml/domain/artifact.py b/rubicon_ml/domain/artifact.py index 66da10a9..4a47f59b 100644 --- a/rubicon_ml/domain/artifact.py +++ b/rubicon_ml/domain/artifact.py @@ -2,16 +2,19 @@ from dataclasses import dataclass, field from datetime import datetime +from typing import List +from rubicon_ml.domain.mixin import TagMixin from rubicon_ml.domain.utils import uuid @dataclass -class Artifact: +class Artifact(TagMixin): name: str id: str = field(default_factory=uuid.uuid4) description: str = None created_at: datetime = field(default_factory=datetime.utcnow) + tags: List[str] = field(default_factory=list) parent_id: str = None diff --git 
a/rubicon_ml/repository/asynchronous/base.py b/rubicon_ml/repository/asynchronous/base.py index bc254bf5..826c8e91 100644 --- a/rubicon_ml/repository/asynchronous/base.py +++ b/rubicon_ml/repository/asynchronous/base.py @@ -819,7 +819,9 @@ async def delete_artifact(self, project_name, artifact_id, experiment_id=None): @connect @invalidate_cache - async def add_tags(self, project_name, tags, experiment_id=None, dataframe_id=None): + async def add_tags( + self, project_name, tags, experiment_id=None, dataframe_id=None, entity_type=None + ): """Overrides `rubicon.repository.BaseRepository.add_tags to asynchronously persist tags to the configured filesystem. @@ -837,14 +839,18 @@ async def add_tags(self, project_name, tags, experiment_id=None, dataframe_id=No The ID of the dataframe to apply the tags `tags` to. """ - tag_metadata_root = self._get_tag_metadata_root(project_name, experiment_id, dataframe_id) + tag_metadata_root = self._get_tag_metadata_root( + project_name, experiment_id, dataframe_id, entity_type + ) tag_metadata_path = f"{tag_metadata_root}/tags_{domain.utils.uuid.uuid4()}.json" await self._persist_domain({"added_tags": tags}, tag_metadata_path) @connect @invalidate_cache - async def remove_tags(self, project_name, tags, experiment_id=None, dataframe_id=None): + async def remove_tags( + self, project_name, tags, experiment_id=None, dataframe_id=None, entity_type=None + ): """Overrides `rubicon.repository.BaseRepository.remove_tags` to asynchronously delete tags from the configured filesystem. @@ -862,13 +868,15 @@ async def remove_tags(self, project_name, tags, experiment_id=None, dataframe_id The ID of the dataframe to delete the tags `tags` from. 
""" - tag_metadata_root = self._get_tag_metadata_root(project_name, experiment_id, dataframe_id) + tag_metadata_root = self._get_tag_metadata_root( + project_name, experiment_id, dataframe_id, entity_type + ) tag_metadata_path = f"{tag_metadata_root}/tags_{domain.utils.uuid.uuid4()}.json" await self._persist_domain({"removed_tags": tags}, tag_metadata_path) @connect - async def get_tags(self, project_name, experiment_id=None, dataframe_id=None): + async def get_tags(self, project_name, experiment_id=None, dataframe_id=None, entity_type=None): """Overrides `rubicon.repository.BaseRepository.get_tags` to asynchronously retrieve tags from the configured filesystem. @@ -890,7 +898,9 @@ async def get_tags(self, project_name, experiment_id=None, dataframe_id=None): value is a list of tag names that have been added to or removed from the specified object. """ - tag_metadata_root = self._get_tag_metadata_root(project_name, experiment_id, dataframe_id) + tag_metadata_root = self._get_tag_metadata_root( + project_name, experiment_id, dataframe_id, entity_type + ) all_paths = await self.filesystem._lsdir(tag_metadata_root) tag_paths = [p for p in all_paths if "/tags_" in p["name"]] diff --git a/rubicon_ml/repository/base.py b/rubicon_ml/repository/base.py index 0e86975c..b199d8f8 100644 --- a/rubicon_ml/repository/base.py +++ b/rubicon_ml/repository/base.py @@ -913,19 +913,31 @@ def get_parameters(self, project_name, experiment_id): # ---------- Tags ---------- - def _get_tag_metadata_root(self, project_name, experiment_id=None, dataframe_id=None): - if dataframe_id is not None: - dataframe_metadata_root = self._get_dataframe_metadata_root(project_name, experiment_id) + def _get_tag_metadata_root( + self, project_name, experiment_id=None, entity_id=None, entity_type=None + ): + """Returns the directory to write tags to.""" + get_metadata_root_lookup = { + "Artifact": self._get_artifact_metadata_root, + "Dataframe": self._get_dataframe_metadata_root, + "Experiment": 
self._get_experiment_metadata_root, + } - return f"{dataframe_metadata_root}/{dataframe_id}" - elif experiment_id is not None: - experiment_metadata_root = self._get_experiment_metadata_root(project_name) + try: + get_metadata_root = get_metadata_root_lookup[entity_type] + except KeyError: + raise ValueError("`experiment_id` and `entity_id` can not both be `None`.") + + if entity_type == "Experiment": + experiment_metadata_root = get_metadata_root(project_name) return f"{experiment_metadata_root}/{experiment_id}" else: - raise ValueError("`experiment_id` and `dataframe_id` can not both be `None`.") + entity_metadata_root = get_metadata_root(project_name, experiment_id) - def add_tags(self, project_name, tags, experiment_id=None, dataframe_id=None): + return f"{entity_metadata_root}/{entity_id}" + + def add_tags(self, project_name, tags, experiment_id=None, entity_id=None, entity_type=None): """Persist tags to the configured filesystem. Parameters @@ -938,16 +950,21 @@ def add_tags(self, project_name, tags, experiment_id=None, dataframe_id=None): experiment_id : str, optional The ID of the experiment to apply the tags `tags` to. - dataframe_id : str, optional - The ID of the dataframe to apply the tags + entity_id : str, optional + The ID of the entity to apply the tags `tags` to. + entity_type : str, optional + The name of the entity's type as returned by + `entity_cls.__class__.__name__`. """ - tag_metadata_root = self._get_tag_metadata_root(project_name, experiment_id, dataframe_id) + tag_metadata_root = self._get_tag_metadata_root( + project_name, experiment_id, entity_id, entity_type + ) tag_metadata_path = f"{tag_metadata_root}/tags_{domain.utils.uuid.uuid4()}.json" self._persist_domain({"added_tags": tags}, tag_metadata_path) - def remove_tags(self, project_name, tags, experiment_id=None, dataframe_id=None): + def remove_tags(self, project_name, tags, experiment_id=None, entity_id=None, entity_type=None): """Delete tags from the configured filesystem. 
Parameters @@ -960,11 +977,16 @@ def remove_tags(self, project_name, tags, experiment_id=None, dataframe_id=None) experiment_id : str, optional The ID of the experiment to delete the tags `tags` from. - dataframe_id : str, optional - The ID of the dataframe to delete the tags - `tags` from. + entity_id : str, optional + The ID of the entity to delete the tags + `tags` from. + entity_type : str, optional + The name of the entity's type as returned by + `entity_cls.__class__.__name__`. """ - tag_metadata_root = self._get_tag_metadata_root(project_name, experiment_id, dataframe_id) + tag_metadata_root = self._get_tag_metadata_root( + project_name, experiment_id, entity_id, entity_type + ) tag_metadata_path = f"{tag_metadata_root}/tags_{domain.utils.uuid.uuid4()}.json" self._persist_domain({"removed_tags": tags}, tag_metadata_path) @@ -983,7 +1005,7 @@ def _sort_tag_paths(self, tag_paths): return tag_paths_with_timestamps - def get_tags(self, project_name, experiment_id=None, dataframe_id=None): + def get_tags(self, project_name, experiment_id=None, entity_id=None, entity_type=None): """Retrieve tags from the configured filesystem. Parameters @@ -993,8 +1015,12 @@ def get_tags(self, project_name, experiment_id=None, dataframe_id=None): tags from belongs to. experiment_id : str, optional The ID of the experiment to retrieve tags from. - dataframe_id : str, optional - The ID of the dataframe to retrieve tags from. + entity_id : str, optional + The ID of the entity to retrieve + tags from. + entity_type : str, optional + The name of the entity's type as returned by + `entity_cls.__class__.__name__`. Returns ------- @@ -1004,7 +1030,9 @@ def get_tags(self, project_name, experiment_id=None, dataframe_id=None): value is a list of tag names that have been added to or removed from the specified object. 
""" - tag_metadata_root = self._get_tag_metadata_root(project_name, experiment_id, dataframe_id) + tag_metadata_root = self._get_tag_metadata_root( + project_name, experiment_id, entity_id, entity_type + ) tag_metadata_glob = f"{tag_metadata_root}/tags_*.json" tag_paths = self.filesystem.glob(tag_metadata_glob, detail=True) diff --git a/tests/unit/client/test_experiment_client.py b/tests/unit/client/test_experiment_client.py index 619a0595..84da9a03 100644 --- a/tests/unit/client/test_experiment_client.py +++ b/tests/unit/client/test_experiment_client.py @@ -31,6 +31,15 @@ def test_properties(project_client): assert experiment.project == project +def test_get_identifiers(project_client): + project = project_client + experiment = project.log_experiment() + project_name, experiment_id = experiment._get_identifiers() + + assert project_name == project.name + assert experiment_id == experiment.id + + def test_log_metric(project_client): project = project_client experiment = project.log_experiment(name="exp1") diff --git a/tests/unit/client/test_mixin_client.py b/tests/unit/client/test_mixin_client.py index ebc65aa7..7f5e6547 100644 --- a/tests/unit/client/test_mixin_client.py +++ b/tests/unit/client/test_mixin_client.py @@ -4,32 +4,10 @@ import pytest -from rubicon_ml.client.mixin import ( - ArtifactMixin, - DataframeMixin, - MultiParentMixin, - TagMixin, -) +from rubicon_ml.client.mixin import ArtifactMixin, DataframeMixin, TagMixin from rubicon_ml.exceptions import RubiconException -def test_get_project_identifiers(project_client): - project = project_client - project_name, experiment_id = MultiParentMixin._get_parent_identifiers(project) - - assert project_name == project.name - assert experiment_id is None - - -def test_get_experiment_identifiers(project_client): - project = project_client - experiment = project.log_experiment() - project_name, experiment_id = MultiParentMixin._get_parent_identifiers(experiment) - - assert project_name == project.name - assert 
experiment_id == experiment.id - - # ArtifactMixin def test_log_artifact_from_bytes(project_client): project = project_client @@ -355,14 +333,45 @@ def test_get_taggable_experiment_identifiers(project_client): def test_get_taggable_dataframe_identifiers(project_client, test_dataframe): project = project_client + experiment = project.log_experiment() + df = test_dataframe - logged_df = project.log_dataframe(df) + project_df = project.log_dataframe(df) + experiment_df = experiment.log_dataframe(df) - project_name, experiment_id, dataframe_id = TagMixin._get_taggable_identifiers(logged_df) + project_name, experiment_id, dataframe_id = TagMixin._get_taggable_identifiers(project_df) assert project_name == project.name assert experiment_id is None - assert dataframe_id == logged_df.id + assert dataframe_id == project_df.id + + project_name, experiment_id, dataframe_id = TagMixin._get_taggable_identifiers(experiment_df) + + assert project_name == project.name + assert experiment_id is experiment.id + assert dataframe_id == experiment_df.id + + +def test_get_taggable_artifact_identifiers(project_client): + project = project_client + experiment = project.log_experiment() + + project_artifact = project.log_artifact(data_bytes=b"test", name="test") + experiment_artifact = experiment.log_artifact(data_bytes=b"test", name="test") + + project_name, experiment_id, artifact_id = TagMixin._get_taggable_identifiers(project_artifact) + + assert project_name == project.name + assert experiment_id is None + assert artifact_id == project_artifact.id + + project_name, experiment_id, artifact_id = TagMixin._get_taggable_identifiers( + experiment_artifact + ) + + assert project_name == project.name + assert experiment_id is experiment.id + assert artifact_id == experiment_artifact.id def test_add_tags(project_client): diff --git a/tests/unit/client/test_project_client.py b/tests/unit/client/test_project_client.py index 93111282..dabea9d3 100644 --- 
a/tests/unit/client/test_project_client.py +++ b/tests/unit/client/test_project_client.py @@ -59,6 +59,14 @@ def test_get_commit_hash(project_client): assert mock_run.mock_calls == expected +def test_get_identifiers(project_client): + project = project_client + project_name, experiment_id = project._get_identifiers() + + assert project_name == project.name + assert experiment_id is None + + def test_create_experiment_with_auto_git(): with mock.patch("subprocess.run") as mock_run: mock_run.return_value = MockCompletedProcess(stdout=b"test", returncode=0) diff --git a/tests/unit/repository/asynchronous/test_asyn_base_repo.py b/tests/unit/repository/asynchronous/test_asyn_base_repo.py index 70910d6d..4eefabf9 100644 --- a/tests/unit/repository/asynchronous/test_asyn_base_repo.py +++ b/tests/unit/repository/asynchronous/test_asyn_base_repo.py @@ -801,7 +801,10 @@ def test_add_tags(asyn_repo_w_mock_filesystem): tags = ["x"] asyncio.run( asyn_repo_w_mock_filesystem.add_tags( - experiment.project_name, tags, experiment_id=experiment.id + experiment.project_name, + tags, + experiment_id=experiment.id, + entity_type=experiment.__class__.__name__, ) ) @@ -818,7 +821,10 @@ def test_remove_tags(asyn_repo_w_mock_filesystem): tags = ["x"] asyncio.run( asyn_repo_w_mock_filesystem.remove_tags( - experiment.project_name, tags, experiment_id=experiment.id + experiment.project_name, + tags, + experiment_id=experiment.id, + entity_type=experiment.__class__.__name__, ) ) @@ -839,7 +845,11 @@ def test_get_tags(asyn_repo_w_mock_filesystem): asyn_repo_w_mock_filesystem.filesystem._cat_file.return_value = '{"test":"test"}' asyncio.run( - asyn_repo_w_mock_filesystem.get_tags(experiment.project_name, experiment_id=experiment.id) + asyn_repo_w_mock_filesystem.get_tags( + experiment.project_name, + experiment_id=experiment.id, + entity_type=experiment.__class__.__name__, + ) ) filesystem_expected = [call._lsdir(ANY), call._cat_file(ANY)] @@ -854,7 +864,11 @@ def 
test_get_tags_with_no_results(asyn_repo_w_mock_filesystem): asyn_repo_w_mock_filesystem.filesystem._lsdir.return_value = [] tags = asyncio.run( - asyn_repo_w_mock_filesystem.get_tags(experiment.project_name, experiment_id=experiment.id) + asyn_repo_w_mock_filesystem.get_tags( + experiment.project_name, + experiment_id=experiment.id, + entity_type=experiment.__class__.__name__, + ) ) filesystem_expected = [call._lsdir(ANY)] diff --git a/tests/unit/repository/test_base_repo.py b/tests/unit/repository/test_base_repo.py index 140a2f00..0e7485d5 100644 --- a/tests/unit/repository/test_base_repo.py +++ b/tests/unit/repository/test_base_repo.py @@ -818,7 +818,9 @@ def test_get_experiment_tags_root(memory_repository): repository = memory_repository experiment = _create_experiment(repository) experiment_tags_root = repository._get_tag_metadata_root( - experiment.project_name, experiment_id=experiment.id + experiment.project_name, + experiment_id=experiment.id, + entity_type=experiment.__class__.__name__, ) assert ( @@ -831,7 +833,9 @@ def test_get_dataframe_tags_with_project_parent_root(memory_repository): repository = memory_repository project = _create_project(repository) dataframe = _create_pandas_dataframe(repository, project=project) - dataframe_tags_root = repository._get_tag_metadata_root(project.name, dataframe_id=dataframe.id) + dataframe_tags_root = repository._get_tag_metadata_root( + project.name, entity_id=dataframe.id, entity_type=dataframe.__class__.__name__ + ) assert ( dataframe_tags_root @@ -848,7 +852,10 @@ def test_get_dataframe_tags_with_experiment_parent_root(memory_repository): repository.create_dataframe(dataframe, dataframe_data, experiment.project_name, experiment.id) dataframe_tags_root = repository._get_tag_metadata_root( - experiment.project_name, experiment_id=experiment.id, dataframe_id=dataframe.id + experiment.project_name, + experiment_id=experiment.id, + entity_id=dataframe.id, + entity_type=dataframe.__class__.__name__, ) assert ( @@ 
-865,13 +872,18 @@ def test_get_root_without_experiment_or_dataframe_throws_error(memory_repository with pytest.raises(ValueError) as e: repository._get_tag_metadata_root(project.name) - assert "`experiment_id` and `dataframe_id` can not both be `None`" in str(e) + assert "`experiment_id` and `entity_id` can not both be `None`" in str(e) def test_add_tags(memory_repository): repository = memory_repository experiment = _create_experiment(repository) - repository.add_tags(experiment.project_name, ["wow"], experiment_id=experiment.id) + repository.add_tags( + experiment.project_name, + ["wow"], + experiment_id=experiment.id, + entity_type=experiment.__class__.__name__, + ) tags_glob = f"{repository.root_dir}/{slugify(experiment.project_name)}/experiments/{experiment.id}/tags_*.json" tags_files = repository.filesystem.glob(tags_glob) @@ -888,7 +900,12 @@ def test_add_tags(memory_repository): def test_remove_tags(memory_repository): repository = memory_repository experiment = _create_experiment(repository, tags=["wow"]) - repository.remove_tags(experiment.project_name, ["wow"], experiment_id=experiment.id) + repository.remove_tags( + experiment.project_name, + ["wow"], + experiment_id=experiment.id, + entity_type=experiment.__class__.__name__, + ) tags_glob = f"{repository.root_dir}/{slugify(experiment.project_name)}/experiments/{experiment.id}/tags_*.json" tags_files = repository.filesystem.glob(tags_glob) @@ -905,10 +922,24 @@ def test_remove_tags(memory_repository): def test_get_tags(memory_repository): repository = memory_repository experiment = _create_experiment(repository, tags=["wow"]) - repository.add_tags(experiment.project_name, ["cool"], experiment_id=experiment.id) - repository.remove_tags(experiment.project_name, ["wow"], experiment_id=experiment.id) + repository.add_tags( + experiment.project_name, + ["cool"], + experiment_id=experiment.id, + entity_type=experiment.__class__.__name__, + ) + repository.remove_tags( + experiment.project_name, + ["wow"], + 
experiment_id=experiment.id, + entity_type=experiment.__class__.__name__, + ) - tags = repository.get_tags(experiment.project_name, experiment_id=experiment.id) + tags = repository.get_tags( + experiment.project_name, + experiment_id=experiment.id, + entity_type=experiment.__class__.__name__, + ) assert {"added_tags": ["cool"]} in tags assert {"removed_tags": ["wow"]} in tags @@ -918,6 +949,10 @@ def test_get_tags_with_no_results(memory_repository): repository = memory_repository experiment = _create_experiment(repository) - tags = repository.get_tags(experiment.project_name, experiment_id=experiment.id) + tags = repository.get_tags( + experiment.project_name, + experiment_id=experiment.id, + entity_type=experiment.__class__.__name__, + ) assert tags == []