pass s3 storage options to dataframe read/write (#484)
* pass storage options to remote dataframes
* more tests & comments
ryanSoley authored Sep 20, 2024
1 parent 32d1a83 commit 5054eb1
Showing 6 changed files with 56 additions and 7 deletions.
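
In practice, this change means that credentials supplied when pointing rubicon_ml at S3 are forwarded to the underlying pandas/dask/polars parquet calls used for dataframe logging, not just to the fsspec filesystem used for metadata. A rough sketch of the user-facing setup this enables (the bucket name and credential values are placeholders, and `key`/`secret` are assumed to be standard s3fs storage options):

    from rubicon_ml import Rubicon

    # placeholder bucket and credentials; any s3fs storage options can be passed as kwargs
    rubicon = Rubicon(
        persistence="filesystem",
        root_dir="s3://my-bucket/rubicon-root",
        key="MY_AWS_ACCESS_KEY_ID",          # forwarded to fsspec/s3fs for metadata logging...
        secret="MY_AWS_SECRET_ACCESS_KEY",   # ...and, after this change, to dataframe read/write too
    )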
10 changes: 6 additions & 4 deletions rubicon_ml/repository/base.py
@@ -43,6 +43,8 @@ class BaseRepository:
     """

     def __init__(self, root_dir: str, **storage_options):
+        self._df_storage_options = {} # should only be non-empty for S3 logging
+
         self.filesystem = fsspec.filesystem(self.PROTOCOL, **storage_options)
         self.root_dir = root_dir.rstrip("/")

@@ -614,7 +616,7 @@ def _persist_dataframe(
             df.write_parquet(path)
         else:
             # Dask or pandas
-            df.to_parquet(path, engine="pyarrow")
+            df.to_parquet(path, engine="pyarrow", storage_options=self._df_storage_options)

     def _read_dataframe(self, path, df_type: Literal["pandas", "dask", "polars"] = "pandas"):
         """Reads the dataframe `df` from the configured filesystem."""
@@ -623,7 +625,7 @@ def _read_dataframe(self, path, df_type: Literal["pandas", "dask", "polars"] = "pandas"):

         if df_type == "pandas":
             path = f"{path}/data.parquet"
-            df = pd.read_parquet(path, engine="pyarrow")
+            df = pd.read_parquet(path, engine="pyarrow", storage_options=self._df_storage_options)
         elif df_type == "polars":
             try:
                 from polars import read_parquet
@@ -633,7 +635,7 @@ def _read_dataframe(self, path, df_type: Literal["pandas", "dask", "polars"] = "pandas"):
                     "to read dataframes with `df_type`='polars'. `pip install polars` "
                     "or `conda install polars` to continue."
                 )
-            df = read_parquet(path)
+            df = read_parquet(path, storage_options=self._df_storage_options)

         elif df_type == "dask":
             try:
@@ -645,7 +647,7 @@ def _read_dataframe(self, path, df_type: Literal["pandas", "dask", "polars"] = "pandas"):
                     "or `conda install dask` to continue."
                 )

-            df = dd.read_parquet(path, engine="pyarrow")
+            df = dd.read_parquet(path, engine="pyarrow", storage_options=self._df_storage_options)
         else:
             raise ValueError(f"`df_type` must be one of {acceptable_types}")

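The pattern above works because pandas, dask, and polars all expose a `storage_options` mapping on their parquet readers/writers that is handed off to fsspec, and an empty mapping is effectively a no-op for local or in-memory paths. A standalone illustration of the same mechanism outside of rubicon_ml (bucket and credentials are placeholders):

    import pandas as pd

    storage_options = {"key": "MY_AWS_ACCESS_KEY_ID", "secret": "MY_AWS_SECRET_ACCESS_KEY"}  # s3fs-style options

    # pandas hands storage_options to fsspec/s3fs when the path is a remote URL
    df = pd.read_parquet("s3://my-bucket/rubicon-root/data.parquet", engine="pyarrow", storage_options=storage_options)
    df.to_parquet("s3://my-bucket/rubicon-root/data.parquet", engine="pyarrow", storage_options=storage_options)

dask's `dd.read_parquet` and polars' `read_parquet` accept the same keyword, which is why the repository's `_df_storage_options` can be passed uniformly to all three backends.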
2 changes: 2 additions & 0 deletions rubicon_ml/repository/memory.py
@@ -28,6 +28,8 @@ class MemoryRepository(LocalRepository):
     PROTOCOL = "memory"

     def __init__(self, root_dir=None, **storage_options):
+        self._df_storage_options = {} # should only be non-empty for S3 logging
+
         self.filesystem = fsspec.filesystem(self.PROTOCOL, **storage_options)
         self.root_dir = root_dir.rstrip("/") if root_dir is not None else "/root"

8 changes: 8 additions & 0 deletions rubicon_ml/repository/s3.py
@@ -1,3 +1,5 @@
+import fsspec
+
 from rubicon_ml.repository import BaseRepository
 from rubicon_ml.repository.utils import json

@@ -18,6 +20,12 @@ class S3Repository(BaseRepository):

     PROTOCOL = "s3"

+    def __init__(self, root_dir: str, **storage_options):
+        self._df_storage_options = storage_options
+
+        self.filesystem = fsspec.filesystem(self.PROTOCOL, **storage_options)
+        self.root_dir = root_dir.rstrip("/")
+
     def _persist_bytes(self, bytes_data, path):
         """Persists the raw bytes `bytes_data` to the S3
         bucket defined by `path`.
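With this addition, `S3Repository` keeps a copy of its storage options specifically for dataframe I/O while still building its fsspec filesystem the same way `BaseRepository` does. A minimal sketch of constructing the repository directly (the import path is assumed, and the bucket/credentials are placeholders):

    from rubicon_ml.repository import S3Repository  # import path assumed to mirror BaseRepository's

    s3_repo = S3Repository(
        root_dir="s3://my-bucket/rubicon-root",
        key="MY_AWS_ACCESS_KEY_ID",          # illustrative s3fs credentials
        secret="MY_AWS_SECRET_ACCESS_KEY",
    )
    # s3_repo._df_storage_options now holds the same options and is forwarded to the
    # pandas/dask/polars parquet calls in BaseRepository._persist_dataframe/_read_dataframe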
3 changes: 2 additions & 1 deletion tests/fixtures.py
@@ -119,6 +119,7 @@ def rubicon_local_filesystem_client():
     rubicon = Rubicon(
         persistence="filesystem",
         root_dir=os.path.join(os.path.dirname(os.path.realpath(__file__)), "rubicon"),
+        storage_option_a="test", # should be ignored when logging local dfs
     )

     # teardown after yield
@@ -221,7 +222,7 @@ def test_dataframe():
 def memory_repository():
     """Setup an in-memory repository and clean it up afterwards."""
     root_dir = "/in-memory-root"
-    repository = MemoryRepository(root_dir)
+    repository = MemoryRepository(root_dir, storage_option_a="test")

     yield repository
     repository.filesystem.rm(root_dir, recursive=True)
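The bogus `storage_option_a="test"` in these fixtures exercises the other side of the behavior: for local and in-memory logging the repository must keep `_df_storage_options` empty, because recent pandas versions reject a non-empty `storage_options` mapping for plain, non-fsspec paths. A quick illustration of that failure mode (hypothetical local path; the exact error text varies by pandas version):

    import pandas as pd

    df = pd.DataFrame({"a": [0, 1]})

    try:
        # passing storage options alongside a plain local path is rejected by pandas
        df.to_parquet("/tmp/data.parquet", engine="pyarrow", storage_options={"storage_option_a": "test"})
    except ValueError as error:
        print(error)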
12 changes: 10 additions & 2 deletions tests/unit/repository/test_base_repo.py
@@ -403,7 +403,11 @@ def test_persist_dataframe(mock_to_parquet, memory_repository):
     # calls `BaseRepository._persist_dataframe` despite class using `MemoryRepository`
     super(MemoryRepository, repository)._persist_dataframe(df, path)

-    mock_to_parquet.assert_called_once_with(f"{path}/data.parquet", engine="pyarrow")
+    mock_to_parquet.assert_called_once_with(
+        f"{path}/data.parquet",
+        engine="pyarrow",
+        storage_options={},
+    )


@patch("polars.DataFrame.write_parquet")
@@ -426,7 +430,11 @@ def test_read_dataframe(mock_read_parquet, memory_repository):
     # calls `BaseRepository._read_dataframe` despite class using `MemoryRepository`
     super(MemoryRepository, repository)._read_dataframe(path)

-    mock_read_parquet.assert_called_once_with(f"{path}/data.parquet", engine="pyarrow")
+    mock_read_parquet.assert_called_once_with(
+        f"{path}/data.parquet",
+        engine="pyarrow",
+        storage_options={},
+    )


def test_read_dataframe_value_error(memory_repository):
28 changes: 28 additions & 0 deletions tests/unit/repository/test_s3_repo.py
@@ -1,6 +1,7 @@
 import uuid
 from unittest.mock import patch

+import pandas as pd
 import pytest
 import s3fs

@@ -49,3 +50,30 @@ def test_persist_domain_throws_error(mock_open):
         s3_repo._persist_domain(project, project_metadata_path)

     mock_open.assert_not_called()
+
+
+@patch("s3fs.core.S3FileSystem.mkdirs")
+@patch("pandas.DataFrame.to_parquet")
+def test_persist_dataframe(mock_to_parquet, mock_mkdirs):
+    s3_repo = S3Repository(root_dir="s3://bucket/root", storage_option_a="test")
+    df = pd.DataFrame([[0, 1], [1, 0]], columns=["a", "b"])
+
+    s3_repo._persist_dataframe(df, s3_repo.root_dir)
+
+    mock_to_parquet.assert_called_once_with(
+        f"{s3_repo.root_dir}/data.parquet",
+        engine="pyarrow",
+        storage_options={"storage_option_a": "test"},
+    )
+
+
+@patch("pandas.read_parquet")
+def test_read_dataframe(mock_read_parquet):
+    s3_repo = S3Repository(root_dir="s3://bucket/root", storage_option_a="test")
+    s3_repo._read_dataframe(s3_repo.root_dir)
+
+    mock_read_parquet.assert_called_once_with(
+        f"{s3_repo.root_dir}/data.parquet",
+        engine="pyarrow",
+        storage_options={"storage_option_a": "test"},
+    )
