From d690c4f6e241d5bad312672cb018bfbf022f189e Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Tue, 30 Jul 2024 10:21:50 +0200 Subject: [PATCH] Set load_from_disk path type as PathLike (#7081) * Set path type as PathLike in load_from_disk * Update docstrings * Update tests * Update save_to_disk docstrings --- src/datasets/arrow_dataset.py | 6 +++--- src/datasets/dataset_dict.py | 9 ++++----- src/datasets/load.py | 12 ++++++++---- tests/test_arrow_dataset.py | 2 +- tests/test_dataset_dict.py | 2 +- tests/test_load.py | 2 +- 6 files changed, 18 insertions(+), 15 deletions(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 5ac987e3f9b..7ba052d3fde 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -1460,7 +1460,7 @@ def save_to_disk( If you want to store paths or urls, please use the Value("string") type. Args: - dataset_path (`str`): + dataset_path (`path-like`): Path (e.g. `dataset/train`) or remote URI (e.g. `s3://my-bucket/dataset/train`) of the dataset directory where the dataset will be saved to. fs (`fsspec.spec.AbstractFileSystem`, *optional*): @@ -1660,7 +1660,7 @@ def _build_local_temp_path(uri_or_path: str) -> Path: @staticmethod def load_from_disk( - dataset_path: str, + dataset_path: PathLike, fs="deprecated", keep_in_memory: Optional[bool] = None, storage_options: Optional[dict] = None, @@ -1670,7 +1670,7 @@ def load_from_disk( filesystem using any implementation of `fsspec.spec.AbstractFileSystem`. Args: - dataset_path (`str`): + dataset_path (`path-like`): Path (e.g. `"dataset/train"`) or remote URI (e.g. `"s3//my-bucket/dataset/train"`) of the dataset directory where the dataset will be loaded from. fs (`fsspec.spec.AbstractFileSystem`, *optional*): diff --git a/src/datasets/dataset_dict.py b/src/datasets/dataset_dict.py index 9e3e8543b77..5d2d9dcd9ff 100644 --- a/src/datasets/dataset_dict.py +++ b/src/datasets/dataset_dict.py @@ -1231,10 +1231,9 @@ def save_to_disk( If you want to store paths or urls, please use the Value("string") type. Args: - dataset_dict_path (`str`): - Path (e.g. `dataset/train`) or remote URI - (e.g. `s3://my-bucket/dataset/train`) of the dataset dict directory where the dataset dict will be - saved to. + dataset_dict_path (`path-like`): + Path (e.g. `dataset/train`) or remote URI (e.g. `s3://my-bucket/dataset/train`) + of the dataset dict directory where the dataset dict will be saved to. fs (`fsspec.spec.AbstractFileSystem`, *optional*): Instance of the remote filesystem where the dataset will be saved to. @@ -1314,7 +1313,7 @@ def load_from_disk( Load a dataset that was previously saved using [`save_to_disk`] from a filesystem using `fsspec.spec.AbstractFileSystem`. Args: - dataset_dict_path (`str`): + dataset_dict_path (`path-like`): Path (e.g. `"dataset/train"`) or remote URI (e.g. `"s3//my-bucket/dataset/train"`) of the dataset dict directory where the dataset dict will be loaded from. fs (`fsspec.spec.AbstractFileSystem`, *optional*): diff --git a/src/datasets/load.py b/src/datasets/load.py index c6e3791a53e..4c330482255 100644 --- a/src/datasets/load.py +++ b/src/datasets/load.py @@ -90,6 +90,7 @@ from .utils.logging import get_logger from .utils.metadata import MetadataConfigs from .utils.py_utils import get_imports, lock_importable_file +from .utils.typing import PathLike from .utils.version import Version @@ -2648,16 +2649,19 @@ def load_dataset( def load_from_disk( - dataset_path: str, fs="deprecated", keep_in_memory: Optional[bool] = None, storage_options: Optional[dict] = None + dataset_path: PathLike, + fs="deprecated", + keep_in_memory: Optional[bool] = None, + storage_options: Optional[dict] = None, ) -> Union[Dataset, DatasetDict]: """ Loads a dataset that was previously saved using [`~Dataset.save_to_disk`] from a dataset directory, or from a filesystem using any implementation of `fsspec.spec.AbstractFileSystem`. Args: - dataset_path (`str`): - Path (e.g. `"dataset/train"`) or remote URI (e.g. - `"s3://my-bucket/dataset/train"`) of the [`Dataset`] or [`DatasetDict`] directory where the dataset will be + dataset_path (`path-like`): + Path (e.g. `"dataset/train"`) or remote URI (e.g. `"s3://my-bucket/dataset/train"`) + of the [`Dataset`] or [`DatasetDict`] directory where the dataset/dataset-dict will be loaded from. fs (`~filesystems.S3FileSystem` or `fsspec.spec.AbstractFileSystem`, *optional*): Instance of the remote filesystem used to download the files from. diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index 01bb71024dc..f747ef7a980 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -4119,7 +4119,7 @@ def test_dummy_dataset_serialize_fs(dataset, mockfs): dataset.save_to_disk(dataset_path, storage_options=mockfs.storage_options) assert mockfs.isdir(dataset_path) assert mockfs.glob(dataset_path + "/*") - reloaded = dataset.load_from_disk(dataset_path, storage_options=mockfs.storage_options) + reloaded = Dataset.load_from_disk(dataset_path, storage_options=mockfs.storage_options) assert len(reloaded) == len(dataset) assert reloaded.features == dataset.features assert reloaded.to_dict() == dataset.to_dict() diff --git a/tests/test_dataset_dict.py b/tests/test_dataset_dict.py index 75419c55778..83f0395718b 100644 --- a/tests/test_dataset_dict.py +++ b/tests/test_dataset_dict.py @@ -596,7 +596,7 @@ def test_dummy_datasetdict_serialize_fs(mockfs): dataset_dict.save_to_disk(dataset_path, storage_options=mockfs.storage_options) assert mockfs.isdir(dataset_path) assert mockfs.glob(dataset_path + "/*") - reloaded = dataset_dict.load_from_disk(dataset_path, storage_options=mockfs.storage_options) + reloaded = DatasetDict.load_from_disk(dataset_path, storage_options=mockfs.storage_options) assert list(reloaded) == list(dataset_dict) for k in dataset_dict: assert reloaded[k].features == dataset_dict[k].features diff --git a/tests/test_load.py b/tests/test_load.py index 38d6852fcf5..7a6fc1baabc 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -1680,7 +1680,7 @@ def test_load_from_disk_with_default_in_memory( expected_in_memory = False dset = load_dataset(dataset_loading_script_dir, data_dir=data_dir, keep_in_memory=True, trust_remote_code=True) - dataset_path = os.path.join(tmp_path, "saved_dataset") + dataset_path = tmp_path / "saved_dataset" dset.save_to_disk(dataset_path) with assert_arrow_memory_increases() if expected_in_memory else assert_arrow_memory_doesnt_increase():