Skip to content

Commit

Permalink
log zipped directories as artifacts (#418)
Browse files Browse the repository at this point in the history
* log directory as artifact
* temporarily download artifacts
* fix existing tests
* add tests
  • Loading branch information
ryanSoley authored Mar 13, 2024
1 parent 8badb1a commit a72a098
Show file tree
Hide file tree
Showing 4 changed files with 181 additions and 25 deletions.
55 changes: 52 additions & 3 deletions rubicon_ml/client/artifact.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from __future__ import annotations

import contextlib
import json
import os
import pickle
import tempfile
import warnings
import zipfile
from typing import TYPE_CHECKING, Literal, Optional

import fsspec
Expand Down Expand Up @@ -112,7 +115,12 @@ def get_json(self):
return json.loads(self.get_data())

@failsafe
def download(self, location: Optional[str] = None, name: Optional[str] = None):
def download(
self,
location: Optional[str] = None,
name: Optional[str] = None,
unzip: bool = False,
):
"""Download this artifact's data.
Parameters
Expand All @@ -125,15 +133,56 @@ def download(self, location: Optional[str] = None, name: Optional[str] = None):
name : str, optional
The name to give the downloaded artifact file.
Defaults to the artifact's given name when logged.
unzip : bool, optional
True to unzip the artifact data. False otherwise.
Defaults to False.
"""
if location is None:
location = os.getcwd()

if name is None:
name = self._domain.name

with fsspec.open(os.path.join(location, name), "wb", auto_mkdir=False) as f:
f.write(self.data)
if unzip:
temp_file_context = tempfile.TemporaryDirectory
else:
temp_file_context = contextlib.nullcontext

with temp_file_context() as temp_dir:
if unzip:
location_path = os.path.join(temp_dir, "temp_file.zip")
else:
location_path = os.path.join(location, name)

with fsspec.open(location_path, "wb", auto_mkdir=False) as f:
f.write(self.data)

if unzip:
with zipfile.ZipFile(location_path, "r") as zip_file:
zip_file.extractall(location)

@contextlib.contextmanager
@failsafe
def temporary_download(self, unzip: bool = False):
"""Temporarily download this artifact's data within a context manager.
Parameters
----------
unzip : bool, optional
True to unzip the artifact data. False otherwise.
Defaults to False.
Yields
------
file
An open file pointer into the directory the artifact data was
temporarily downloaded into. If the artifact is a single file,
its name is stored in the `artifact.name` attribute.
"""
with tempfile.TemporaryDirectory() as temp_dir:
self.download(location=temp_dir, unzip=unzip)

yield temp_dir

@property
def id(self) -> str:
Expand Down
76 changes: 55 additions & 21 deletions rubicon_ml/client/mixin.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from __future__ import annotations

import contextlib
import json
import os
import pickle
import subprocess
import tempfile
import warnings
import zipfile
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, TextIO, Union

import fsspec
Expand All @@ -30,13 +33,12 @@ class ArtifactMixin:

_domain: ArtifactDomain

def _validate_data(self, data_bytes, data_file, data_object, data_path, name):
"""Raises a `RubiconException` if the data to log as
an artifact is improperly provided.
"""
if not any([data_bytes, data_file, data_object, data_path]):
def _validate_data(self, data_bytes, data_directory, data_file, data_object, data_path, name):
"""Raises a `RubiconException` if the data to log as an artifact is improperly provided."""
if not any([data_bytes, data_directory, data_file, data_object, data_path]):
raise RubiconException(
"One of `data_bytes`, `data_file`, `data_object` or `data_path` must be provided."
"One of `data_bytes`, `data_directory`, `data_file`, `data_object` or "
"`data_path` must be provided."
)

if name is None:
Expand All @@ -45,24 +47,40 @@ def _validate_data(self, data_bytes, data_file, data_object, data_path, name):
else:
raise RubiconException("`name` must be provided if not using `data_path`.")

if data_directory is not None:
temp_file_context = tempfile.TemporaryDirectory
else:
temp_file_context = contextlib.nullcontext

if data_bytes is None:
if data_object is not None:
data_bytes = pickle.dumps(data_object)
else:
if data_file is not None:
f = data_file
elif data_path is not None:
f = fsspec.open(data_path, "rb")
with temp_file_context() as temp_dir:
if data_object is not None:
data_bytes = pickle.dumps(data_object)
else:
if data_directory is not None:
temp_zip_name = Path(f"{temp_dir}/{name}")

with zipfile.ZipFile(str(temp_zip_name), "w") as zip_file:
for dir_path, _, files in os.walk(data_directory):
for file in files:
zip_file.write(Path(f"{dir_path}/{file}"), arcname=file)

with f as open_file:
data_bytes = open_file.read()
file = fsspec.open(temp_zip_name, "rb")
elif data_file is not None:
file = data_file
elif data_path is not None:
file = fsspec.open(data_path, "rb")

with file as open_file:
data_bytes = open_file.read()

return data_bytes, name

@failsafe
def log_artifact(
self,
data_bytes: Optional[bytes] = None,
data_directory: Optional[str] = None,
data_file: Optional[TextIO] = None,
data_object: Optional[Any] = None,
data_path: Optional[str] = None,
Expand All @@ -77,6 +95,8 @@ def log_artifact(
----------
data_bytes : bytes, optional
The raw bytes to log as an artifact.
data_directory : str, optional
The path to a directory to zip and log as an artifact.
data_file : TextIOWrapper, optional
The open file to log as an artifact.
data_object : python object, optional
Expand Down Expand Up @@ -113,18 +133,30 @@ def log_artifact(
--------
>>> # Log with bytes
>>> experiment.log_artifact(
... data_bytes=b'hello rubicon!', name='bytes_artifact', description="log artifact from bytes"
... data_bytes=b'hello rubicon!',
... name="bytes_artifact",
... description="log artifact from bytes",
... )
>>> # Log zipped directory
>>> experiment.log_artifact(
... data_directory="./path/to/directory/",
... name="directory.zip",
... description="log artifact from zipped directory",
... )
>>> # Log with file
>>> with open('some_relevant_file', 'rb') as f:
>>> with open('./path/to/artifact.txt', 'rb') as file:
>>> project.log_artifact(
... data_file=f, name='file_artifact', description="log artifact from file"
... )
... data_file=file,
... name="file_artifact",
... description="log artifact from file",
... )
>>> # Log with file path
>>> experiment.log_artifact(
... data_path="./path/to/artifact.pkl", description="log artifact from file path"
... data_path="./path/to/artifact.pkl",
... description="log artifact from file path",
... )
"""
if tags is None:
Expand All @@ -139,7 +171,9 @@ def log_artifact(
):
raise ValueError("`comments` must be `list` of type `str`")

data_bytes, name = self._validate_data(data_bytes, data_file, data_object, data_path, name)
data_bytes, name = self._validate_data(
data_bytes, data_directory, data_file, data_object, data_path, name
)

artifact = domain.Artifact(
name=name,
Expand Down
41 changes: 41 additions & 0 deletions tests/unit/client/test_artifact_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import os
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch

import h2o
Expand Down Expand Up @@ -164,3 +166,42 @@ def test_get_data_deserialize_h2o(
artifact_data = artifact.get_data(deserialize="h2o")

assert artifact_data.__class__ == h2o_model.__class__


def test_download_data_unzip(project_client):
"""Test downloading and unzipping artifact data."""
project = project_client

with tempfile.TemporaryDirectory() as temp_dir:
with open(Path(temp_dir, "test_file_a.txt"), "w") as file:
file.write("testing rubicon")

with open(Path(temp_dir, "test_file_b.txt"), "w") as file:
file.write("testing rubicon again")

artifact = project.log_artifact(
data_directory=temp_dir,
name="test.zip",
tags=["x"],
comments=["this is a comment"],
)

Path(temp_dir, "test_file_a.txt").unlink()
Path(temp_dir, "test_file_b.txt").unlink()

artifact.download(location=temp_dir, unzip=True)

assert Path(temp_dir, "test_file_a.txt").exists()
assert Path(temp_dir, "test_file_b.txt").exists()


def test_temporary_download(project_client):
"""Test temporarily downloading artifact data."""
project = project_client
data = b"content"
artifact = project.log_artifact(name="test.txt", data_bytes=data)

with artifact.temporary_download() as temp_artifact_dir:
assert Path(temp_artifact_dir, "test.txt").exists()

assert not Path(temp_artifact_dir, "test.txt").exists()
34 changes: 33 additions & 1 deletion tests/unit/client/test_mixin_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import subprocess
import tempfile
import warnings
from pathlib import Path
from unittest import mock
from unittest.mock import MagicMock, patch

Expand Down Expand Up @@ -36,6 +38,35 @@ def test_log_artifact_from_bytes(project_client):
assert artifact.comments == ["this is a comment"]


def test_log_artifact_from_directory(project_client):
"""Test logging artifacts from a directory."""
project = project_client

with tempfile.TemporaryDirectory() as temp_dir:
with open(Path(temp_dir, "test_file_a.txt"), "w") as file:
file.write("testing rubicon")

with open(Path(temp_dir, "test_file_b.txt"), "w") as file:
file.write("testing rubicon again")

artifact = ArtifactMixin.log_artifact(
project,
data_directory=temp_dir,
name="test.zip",
tags=["x"],
comments=["this is a comment"],
)

assert artifact.id in [a.id for a in project.artifacts()]
assert artifact.name == "test.zip"
assert artifact.tags == ["x"]
assert artifact.comments == ["this is a comment"]

with artifact.temporary_download(unzip=True) as temp_artifact_dir:
assert Path(temp_artifact_dir, "test_file_a.txt").exists()
assert Path(temp_artifact_dir, "test_file_b.txt").exists()


def test_log_artifact_from_file(project_client):
project = project_client
mock_file = MagicMock()
Expand Down Expand Up @@ -74,7 +105,8 @@ def test_log_artifact_throws_error_if_data_missing(project_client):
ArtifactMixin.log_artifact(project, name="test.txt")

assert (
"One of `data_bytes`, `data_file`, `data_object` or `data_path` must be provided." in str(e)
"One of `data_bytes`, `data_directory`, `data_file`, `data_object` "
"or `data_path` must be provided." in str(e)
)


Expand Down

0 comments on commit a72a098

Please sign in to comment.