Skip to content

Commit

Permalink
add tagging to artifacts (#268)
Browse files Browse the repository at this point in the history
* add tags to artifact domain
* add tagging functionality to artifacts
* pass tests
* remove `MultiParentMixin`
* `_get_parent_identifiers` -> `_get_identifiers`
* update docstrings
  • Loading branch information
ryanSoley authored Sep 15, 2022
1 parent 2d403d6 commit 0633fd1
Show file tree
Hide file tree
Showing 15 changed files with 253 additions and 125 deletions.
7 changes: 4 additions & 3 deletions rubicon_ml/client/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

import fsspec

from rubicon_ml.client import Base
from rubicon_ml.client.base import Base
from rubicon_ml.client.mixin import TagMixin


class Artifact(Base):
class Artifact(Base, TagMixin):
"""A client artifact.
An `artifact` is a catch-all for any other type of
Expand Down Expand Up @@ -35,7 +36,7 @@ def __init__(self, domain, parent):

def _get_data(self):
"""Loads the data associated with this artifact."""
project_name, experiment_id = self.parent._get_parent_identifiers()
project_name, experiment_id = self.parent._get_identifiers()

self._data = self.repository.get_artifact_data(
project_name, self.id, experiment_id=experiment_id
Expand Down
2 changes: 1 addition & 1 deletion rubicon_ml/client/asynchronous/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ async def _get_data(self):
"""Overrides `rubicon.client.Artifact._get_data` to
asynchronously load the data associated with this artifact.
"""
project_name, experiment_id = self.parent._get_parent_identifiers()
project_name, experiment_id = self.parent._get_identifiers()

self._data = await self.repository.get_artifact_data(
project_name, self.id, experiment_id=experiment_id
Expand Down
17 changes: 8 additions & 9 deletions rubicon_ml/client/asynchronous/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@

from rubicon_ml import domain
from rubicon_ml.client import asynchronous as client
from rubicon_ml.client.mixin import MultiParentMixin
from rubicon_ml.client.utils.tags import has_tag_requirements


class ArtifactMixin(MultiParentMixin):
class ArtifactMixin:
"""Adds artifact support to an asynchronous client object."""

async def log_artifact(
Expand Down Expand Up @@ -66,7 +65,7 @@ async def log_artifact(

artifact = domain.Artifact(name=name, description=description, parent_id=self._domain.id)

project_name, experiment_id = self._get_parent_identifiers()
project_name, experiment_id = self._get_identifiers()
await self.repository.create_artifact(
artifact, data_bytes, project_name, experiment_id=experiment_id
)
Expand Down Expand Up @@ -130,7 +129,7 @@ async def artifacts(self):
list of rubicon.client.Artifact
The artifacts previously logged to this client object.
"""
project_name, experiment_id = self._get_parent_identifiers()
project_name, experiment_id = self._get_identifiers()

self._artifacts = [
client.Artifact(a, self)
Expand All @@ -151,7 +150,7 @@ async def delete_artifacts(self, ids):
ids : list of str
The ids of the artifacts to delete.
"""
project_name, experiment_id = self._get_parent_identifiers()
project_name, experiment_id = self._get_identifiers()

await asyncio.gather(
*[
Expand All @@ -163,7 +162,7 @@ async def delete_artifacts(self, ids):
)


class DataframeMixin(MultiParentMixin):
class DataframeMixin:
"""Adds dataframe support to an asynchronous client object."""

async def log_dataframe(self, df, description=None, tags=[]):
Expand All @@ -187,7 +186,7 @@ async def log_dataframe(self, df, description=None, tags=[]):
"""
dataframe = domain.Dataframe(parent_id=self._domain.id, description=description, tags=tags)

project_name, experiment_id = self._get_parent_identifiers()
project_name, experiment_id = self._get_identifiers()
await self.repository.create_dataframe(
dataframe, df, project_name, experiment_id=experiment_id
)
Expand Down Expand Up @@ -226,7 +225,7 @@ async def dataframes(self, tags=[], qtype="or"):
list of rubicon.client.Dataframe
The dataframes previously logged to this client object.
"""
project_name, experiment_id = self._get_parent_identifiers()
project_name, experiment_id = self._get_identifiers()
dataframes = [
client.Dataframe(d, self)
for d in await self.repository.get_dataframes_metadata(
Expand All @@ -248,7 +247,7 @@ async def delete_dataframes(self, ids):
ids : list of str
The ids of the dataframes to delete.
"""
project_name, experiment_id = self._get_parent_identifiers()
project_name, experiment_id = self._get_identifiers()

await asyncio.gather(
*[
Expand Down
2 changes: 1 addition & 1 deletion rubicon_ml/client/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def get_data(self, df_type="pandas"):
The type of dataframe. Can be either `pandas` or `dask`.
Defaults to 'pandas'.
"""
project_name, experiment_id = self.parent._get_parent_identifiers()
project_name, experiment_id = self.parent._get_identifiers()

self._data = self.repository.get_dataframe_data(
project_name,
Expand Down
4 changes: 4 additions & 0 deletions rubicon_ml/client/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ def __init__(self, domain, parent):
self._features = []
self._parameters = []

def _get_identifiers(self):
"""Get the experiment's project's name and the experiment's ID."""
return self.project.name, self.id

def log_metric(self, name, value, directionality="score", description=None):
"""Create a metric under the experiment.
Expand Down
92 changes: 48 additions & 44 deletions rubicon_ml/client/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,7 @@
from rubicon_ml.exceptions import RubiconException


class MultiParentMixin:
"""Adds utils for client objects that can be logged
to either a `Project` or `Experiment`.
"""

def _get_parent_identifiers(self):
"""Get the project name and experiment ID (or
`None`) of this client object's parent(s).
"""
experiment_id = None

if isinstance(self, client.Project):
project_name = self.name
else:
project_name = self.project.name
experiment_id = self.id

return project_name, experiment_id


class ArtifactMixin(MultiParentMixin):
class ArtifactMixin:
"""Adds artifact support to a client object."""

def _validate_data(self, data_bytes, data_file, data_path, name):
Expand Down Expand Up @@ -60,7 +40,13 @@ def _validate_data(self, data_bytes, data_file, data_path, name):
return data_bytes, name

def log_artifact(
self, data_bytes=None, data_file=None, data_path=None, name=None, description=None
self,
data_bytes=None,
data_file=None,
data_path=None,
name=None,
description=None,
tags=[],
):
"""Log an artifact to this client object.
Expand All @@ -80,6 +66,9 @@ def log_artifact(
description : str, optional
A description of the artifact. Use to provide
additional context.
tags : list of str, optional
Values to tag the experiment with. Use tags to organize and
filter your artifacts.
Notes
-----
Expand Down Expand Up @@ -112,9 +101,14 @@ def log_artifact(
"""
data_bytes, name = self._validate_data(data_bytes, data_file, data_path, name)

artifact = domain.Artifact(name=name, description=description, parent_id=self._domain.id)
artifact = domain.Artifact(
name=name,
description=description,
parent_id=self._domain.id,
tags=tags,
)

project_name, experiment_id = self._get_parent_identifiers()
project_name, experiment_id = self._get_identifiers()
self.repository.create_artifact(
artifact, data_bytes, project_name, experiment_id=experiment_id
)
Expand Down Expand Up @@ -202,7 +196,7 @@ def artifacts(self, name=None):
list of rubicon.client.Artifact
The artifacts previously logged to this client object.
"""
project_name, experiment_id = self._get_parent_identifiers()
project_name, experiment_id = self._get_identifiers()
self._artifacts = [
client.Artifact(a, self)
for a in self.repository.get_artifacts_metadata(
Expand Down Expand Up @@ -245,7 +239,7 @@ def artifact(self, name=None, id=None):

artifact = artifacts[-1]
else:
project_name, experiment_id = self._get_parent_identifiers()
project_name, experiment_id = self._get_identifiers()
artifact = client.Artifact(
self.repository.get_artifact_metadata(project_name, id, experiment_id), self
)
Expand All @@ -261,13 +255,13 @@ def delete_artifacts(self, ids):
ids : list of str
The ids of the artifacts to delete.
"""
project_name, experiment_id = self._get_parent_identifiers()
project_name, experiment_id = self._get_identifiers()

for artifact_id in ids:
self.repository.delete_artifact(project_name, artifact_id, experiment_id=experiment_id)


class DataframeMixin(MultiParentMixin):
class DataframeMixin:
"""Adds dataframe support to a client object."""

def log_dataframe(self, df, description=None, name=None, tags=[]):
Expand Down Expand Up @@ -295,7 +289,7 @@ def log_dataframe(self, df, description=None, name=None, tags=[]):
tags=tags,
)

project_name, experiment_id = self._get_parent_identifiers()
project_name, experiment_id = self._get_identifiers()
self.repository.create_dataframe(dataframe, df, project_name, experiment_id=experiment_id)

return client.Dataframe(dataframe, self)
Expand Down Expand Up @@ -336,7 +330,7 @@ def dataframes(self, name=None, tags=[], qtype="or"):
list of rubicon.client.Dataframe
The dataframes previously logged to this client object.
"""
project_name, experiment_id = self._get_parent_identifiers()
project_name, experiment_id = self._get_identifiers()
dataframes = [
client.Dataframe(d, self)
for d in self.repository.get_dataframes_metadata(
Expand Down Expand Up @@ -379,7 +373,7 @@ def dataframe(self, name=None, id=None):

dataframe = dataframes[-1]
else:
project_name, experiment_id = self._get_parent_identifiers()
project_name, experiment_id = self._get_identifiers()
dataframe = client.Dataframe(
self.repository.get_dataframe_metadata(
project_name, experiment_id=experiment_id, dataframe_id=id
Expand All @@ -398,7 +392,7 @@ def delete_dataframes(self, ids):
ids : list of str
The ids of the dataframes to delete.
"""
project_name, experiment_id = self._get_parent_identifiers()
project_name, experiment_id = self._get_identifiers()

for dataframe_id in ids:
self.repository.delete_dataframe(
Expand All @@ -410,15 +404,16 @@ class TagMixin:
"""Adds tag support to a client object."""

def _get_taggable_identifiers(self):
dataframe_id = None
project_name, experiment_id = self._parent._get_identifiers()
entity_id = None

if isinstance(self, client.Dataframe):
project_name, experiment_id = self.parent._get_parent_identifiers()
dataframe_id = self.id
# experiments are not required to return an entity ID - they are the entity
if isinstance(self, client.Experiment):
experiment_id = self.id
else:
project_name, experiment_id = self._get_parent_identifiers()
entity_id = self.id

return project_name, experiment_id, dataframe_id
return project_name, experiment_id, entity_id

def add_tags(self, tags):
"""Add tags to this client object.
Expand All @@ -428,11 +423,15 @@ def add_tags(self, tags):
tags : list of str
The tag values to add.
"""
project_name, experiment_id, dataframe_id = self._get_taggable_identifiers()
project_name, experiment_id, entity_id = self._get_taggable_identifiers()

self._domain.add_tags(tags)
self.repository.add_tags(
project_name, tags, experiment_id=experiment_id, dataframe_id=dataframe_id
project_name,
tags,
experiment_id=experiment_id,
entity_id=entity_id,
entity_type=self.__class__.__name__,
)

def remove_tags(self, tags):
Expand All @@ -443,11 +442,15 @@ def remove_tags(self, tags):
tags : list of str
The tag values to remove.
"""
project_name, experiment_id, dataframe_id = self._get_taggable_identifiers()
project_name, experiment_id, entity_id = self._get_taggable_identifiers()

self._domain.remove_tags(tags)
self.repository.remove_tags(
project_name, tags, experiment_id=experiment_id, dataframe_id=dataframe_id
project_name,
tags,
experiment_id=experiment_id,
entity_id=entity_id,
entity_type=self.__class__.__name__,
)

def _update_tags(self, tag_data):
Expand All @@ -461,11 +464,12 @@ def _update_tags(self, tag_data):
@property
def tags(self):
"""Get this client object's tags."""
project_name, experiment_id, dataframe_id = self._get_taggable_identifiers()
project_name, experiment_id, entity_id = self._get_taggable_identifiers()
tag_data = self.repository.get_tags(
project_name,
experiment_id=experiment_id,
dataframe_id=dataframe_id,
entity_id=entity_id,
entity_type=self.__class__.__name__,
)

self._update_tags(tag_data)
Expand Down
4 changes: 4 additions & 0 deletions rubicon_ml/client/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ def _get_commit_hash(self):

return completed_process.stdout.decode("utf8").replace("\n", "")

def _get_identifiers(self):
"""Get the project's name."""
return self.name, None

def _create_experiment_domain(
self,
name,
Expand Down
5 changes: 4 additions & 1 deletion rubicon_ml/domain/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,19 @@

from dataclasses import dataclass, field
from datetime import datetime
from typing import List

from rubicon_ml.domain.mixin import TagMixin
from rubicon_ml.domain.utils import uuid


@dataclass
class Artifact:
class Artifact(TagMixin):
name: str

id: str = field(default_factory=uuid.uuid4)
description: str = None
created_at: datetime = field(default_factory=datetime.utcnow)
tags: List[str] = field(default_factory=list)

parent_id: str = None
Loading

0 comments on commit 0633fd1

Please sign in to comment.