From a72a0986f771e3ed59e5e47dd2dc588ed6f68264 Mon Sep 17 00:00:00 2001 From: Ryan Soley Date: Wed, 13 Mar 2024 13:32:47 -0400 Subject: [PATCH] log zipped directories as artifacts (#418) * log directory as artifact * temporarily download artifacts * fix existing tests * add tests --- rubicon_ml/client/artifact.py | 55 +++++++++++++++- rubicon_ml/client/mixin.py | 76 ++++++++++++++++------- tests/unit/client/test_artifact_client.py | 41 ++++++++++++ tests/unit/client/test_mixin_client.py | 34 +++++++++- 4 files changed, 181 insertions(+), 25 deletions(-) diff --git a/rubicon_ml/client/artifact.py b/rubicon_ml/client/artifact.py index 21709ca7..238fac21 100644 --- a/rubicon_ml/client/artifact.py +++ b/rubicon_ml/client/artifact.py @@ -1,9 +1,12 @@ from __future__ import annotations +import contextlib import json import os import pickle +import tempfile import warnings +import zipfile from typing import TYPE_CHECKING, Literal, Optional import fsspec @@ -112,7 +115,12 @@ def get_json(self): return json.loads(self.get_data()) @failsafe - def download(self, location: Optional[str] = None, name: Optional[str] = None): + def download( + self, + location: Optional[str] = None, + name: Optional[str] = None, + unzip: bool = False, + ): """Download this artifact's data. Parameters @@ -125,6 +133,9 @@ def download(self, location: Optional[str] = None, name: Optional[str] = None): name : str, optional The name to give the downloaded artifact file. Defaults to the artifact's given name when logged. + unzip : bool, optional + True to unzip the artifact data. False otherwise. + Defaults to False. 
""" if location is None: location = os.getcwd() @@ -132,8 +143,46 @@ def download(self, location: Optional[str] = None, name: Optional[str] = None): if name is None: name = self._domain.name - with fsspec.open(os.path.join(location, name), "wb", auto_mkdir=False) as f: - f.write(self.data) + if unzip: + temp_file_context = tempfile.TemporaryDirectory + else: + temp_file_context = contextlib.nullcontext + + with temp_file_context() as temp_dir: + if unzip: + location_path = os.path.join(temp_dir, "temp_file.zip") + else: + location_path = os.path.join(location, name) + + with fsspec.open(location_path, "wb", auto_mkdir=False) as f: + f.write(self.data) + + if unzip: + with zipfile.ZipFile(location_path, "r") as zip_file: + zip_file.extractall(location) + + @contextlib.contextmanager + @failsafe + def temporary_download(self, unzip: bool = False): + """Temporarily download this artifact's data within a context manager. + + Parameters + ---------- + unzip : bool, optional + True to unzip the artifact data. False otherwise. + Defaults to False. + + Yields + ------ + str + The path to the temporary directory the artifact data was + downloaded into. If the artifact is a single file, + its name is stored in the `artifact.name` attribute. 
+ """ + with tempfile.TemporaryDirectory() as temp_dir: + self.download(location=temp_dir, unzip=unzip) + + yield temp_dir @property def id(self) -> str: diff --git a/rubicon_ml/client/mixin.py b/rubicon_ml/client/mixin.py index fc66f128..797e6e3c 100644 --- a/rubicon_ml/client/mixin.py +++ b/rubicon_ml/client/mixin.py @@ -1,12 +1,15 @@ from __future__ import annotations +import contextlib import json import os import pickle import subprocess import tempfile import warnings +import zipfile from datetime import datetime +from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, List, Optional, TextIO, Union import fsspec @@ -30,13 +33,12 @@ class ArtifactMixin: _domain: ArtifactDomain - def _validate_data(self, data_bytes, data_file, data_object, data_path, name): - """Raises a `RubiconException` if the data to log as - an artifact is improperly provided. - """ - if not any([data_bytes, data_file, data_object, data_path]): + def _validate_data(self, data_bytes, data_directory, data_file, data_object, data_path, name): + """Raises a `RubiconException` if the data to log as an artifact is improperly provided.""" + if not any([data_bytes, data_directory, data_file, data_object, data_path]): raise RubiconException( - "One of `data_bytes`, `data_file`, `data_object` or `data_path` must be provided." + "One of `data_bytes`, `data_directory`, `data_file`, `data_object` or " + "`data_path` must be provided." 
) if name is None: @@ -45,17 +47,32 @@ def _validate_data(self, data_bytes, data_file, data_object, data_path, name): else: raise RubiconException("`name` must be provided if not using `data_path`.") + if data_directory is not None: + temp_file_context = tempfile.TemporaryDirectory + else: + temp_file_context = contextlib.nullcontext + if data_bytes is None: - if data_object is not None: - data_bytes = pickle.dumps(data_object) - else: - if data_file is not None: - f = data_file - elif data_path is not None: - f = fsspec.open(data_path, "rb") + with temp_file_context() as temp_dir: + if data_object is not None: + data_bytes = pickle.dumps(data_object) + else: + if data_directory is not None: + temp_zip_name = Path(f"{temp_dir}/{name}") + + with zipfile.ZipFile(str(temp_zip_name), "w") as zip_file: + for dir_path, _, files in os.walk(data_directory): + for file in files: + zip_file.write(Path(f"{dir_path}/{file}"), arcname=file) - with f as open_file: - data_bytes = open_file.read() + file = fsspec.open(temp_zip_name, "rb") + elif data_file is not None: + file = data_file + elif data_path is not None: + file = fsspec.open(data_path, "rb") + + with file as open_file: + data_bytes = open_file.read() return data_bytes, name @@ -63,6 +80,7 @@ def _validate_data(self, data_bytes, data_file, data_object, data_path, name): def log_artifact( self, data_bytes: Optional[bytes] = None, + data_directory: Optional[str] = None, data_file: Optional[TextIO] = None, data_object: Optional[Any] = None, data_path: Optional[str] = None, @@ -77,6 +95,8 @@ def log_artifact( ---------- data_bytes : bytes, optional The raw bytes to log as an artifact. + data_directory : str, optional + The path to a directory to zip and log as an artifact. data_file : TextIOWrapper, optional The open file to log as an artifact. data_object : python object, optional @@ -113,18 +133,30 @@ def log_artifact( -------- >>> # Log with bytes >>> experiment.log_artifact( - ... 
data_bytes=b'hello rubicon!', name='bytes_artifact', description="log artifact from bytes" + ... data_bytes=b'hello rubicon!', + ... name="bytes_artifact", + ... description="log artifact from bytes", + ... ) + + >>> # Log zipped directory + >>> experiment.log_artifact( + ... data_directory="./path/to/directory/", + ... name="directory.zip", + ... description="log artifact from zipped directory", ... ) >>> # Log with file - >>> with open('some_relevant_file', 'rb') as f: + >>> with open('./path/to/artifact.txt', 'rb') as file: >>> project.log_artifact( - ... data_file=f, name='file_artifact', description="log artifact from file" - ... ) + ... data_file=file, + ... name="file_artifact", + ... description="log artifact from file", + ... ) >>> # Log with file path >>> experiment.log_artifact( - ... data_path="./path/to/artifact.pkl", description="log artifact from file path" + ... data_path="./path/to/artifact.pkl", + ... description="log artifact from file path", ... ) """ if tags is None: @@ -139,7 +171,9 @@ def log_artifact( ): raise ValueError("`comments` must be `list` of type `str`") - data_bytes, name = self._validate_data(data_bytes, data_file, data_object, data_path, name) + data_bytes, name = self._validate_data( + data_bytes, data_directory, data_file, data_object, data_path, name + ) artifact = domain.Artifact( name=name, diff --git a/tests/unit/client/test_artifact_client.py b/tests/unit/client/test_artifact_client.py index 713785bc..9e2d4614 100644 --- a/tests/unit/client/test_artifact_client.py +++ b/tests/unit/client/test_artifact_client.py @@ -1,4 +1,6 @@ import os +import tempfile +from pathlib import Path from unittest.mock import MagicMock, patch import h2o @@ -164,3 +166,42 @@ def test_get_data_deserialize_h2o( artifact_data = artifact.get_data(deserialize="h2o") assert artifact_data.__class__ == h2o_model.__class__ + + +def test_download_data_unzip(project_client): + """Test downloading and unzipping artifact data.""" + project = project_client + 
+ with tempfile.TemporaryDirectory() as temp_dir: + with open(Path(temp_dir, "test_file_a.txt"), "w") as file: + file.write("testing rubicon") + + with open(Path(temp_dir, "test_file_b.txt"), "w") as file: + file.write("testing rubicon again") + + artifact = project.log_artifact( + data_directory=temp_dir, + name="test.zip", + tags=["x"], + comments=["this is a comment"], + ) + + Path(temp_dir, "test_file_a.txt").unlink() + Path(temp_dir, "test_file_b.txt").unlink() + + artifact.download(location=temp_dir, unzip=True) + + assert Path(temp_dir, "test_file_a.txt").exists() + assert Path(temp_dir, "test_file_b.txt").exists() + + +def test_temporary_download(project_client): + """Test temporarily downloading artifact data.""" + project = project_client + data = b"content" + artifact = project.log_artifact(name="test.txt", data_bytes=data) + + with artifact.temporary_download() as temp_artifact_dir: + assert Path(temp_artifact_dir, "test.txt").exists() + + assert not Path(temp_artifact_dir, "test.txt").exists() diff --git a/tests/unit/client/test_mixin_client.py b/tests/unit/client/test_mixin_client.py index 1f40bd54..fdcddcca 100644 --- a/tests/unit/client/test_mixin_client.py +++ b/tests/unit/client/test_mixin_client.py @@ -1,5 +1,7 @@ import subprocess +import tempfile import warnings +from pathlib import Path from unittest import mock from unittest.mock import MagicMock, patch @@ -36,6 +38,35 @@ def test_log_artifact_from_bytes(project_client): assert artifact.comments == ["this is a comment"] +def test_log_artifact_from_directory(project_client): + """Test logging artifacts from a directory.""" + project = project_client + + with tempfile.TemporaryDirectory() as temp_dir: + with open(Path(temp_dir, "test_file_a.txt"), "w") as file: + file.write("testing rubicon") + + with open(Path(temp_dir, "test_file_b.txt"), "w") as file: + file.write("testing rubicon again") + + artifact = ArtifactMixin.log_artifact( + project, + data_directory=temp_dir, + name="test.zip", + 
tags=["x"], + comments=["this is a comment"], + ) + + assert artifact.id in [a.id for a in project.artifacts()] + assert artifact.name == "test.zip" + assert artifact.tags == ["x"] + assert artifact.comments == ["this is a comment"] + + with artifact.temporary_download(unzip=True) as temp_artifact_dir: + assert Path(temp_artifact_dir, "test_file_a.txt").exists() + assert Path(temp_artifact_dir, "test_file_b.txt").exists() + + def test_log_artifact_from_file(project_client): project = project_client mock_file = MagicMock() @@ -74,7 +105,8 @@ def test_log_artifact_throws_error_if_data_missing(project_client): ArtifactMixin.log_artifact(project, name="test.txt") assert ( - "One of `data_bytes`, `data_file`, `data_object` or `data_path` must be provided." in str(e) + "One of `data_bytes`, `data_directory`, `data_file`, `data_object` " + "or `data_path` must be provided." in str(e) )