From 91b07b90915d7f7313d44ca3ff67673b9ad26bf4 Mon Sep 17 00:00:00 2001 From: Xin Zhang Date: Mon, 15 Apr 2024 23:38:50 +0800 Subject: [PATCH] Fix cache path to snakecase for `CachedDatasetModuleFactory` and `Cache` (#6754) * fix cache path snakecase * Apply suggestions from code review Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com> * fix variable names of suggestions * add test capital letters dataset to `test_load.py` and `test_cache.py` --------- Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com> --- src/datasets/load.py | 5 ++++- src/datasets/packaged_modules/cache/cache.py | 7 ++++-- tests/packaged_modules/test_cache.py | 23 ++++++++++++++++++++ tests/test_load.py | 14 ++++++++++++ 4 files changed, 46 insertions(+), 3 deletions(-) diff --git a/src/datasets/load.py b/src/datasets/load.py index 73068e153d2..6e3ebdcb1b7 100644 --- a/src/datasets/load.py +++ b/src/datasets/load.py @@ -1608,7 +1608,10 @@ def _get_modification_time(module_hash): } return DatasetModule(module_path, hash, builder_kwargs, importable_file_path=importable_file_path) cache_dir = os.path.expanduser(str(self.cache_dir or config.HF_DATASETS_CACHE)) - cached_datasets_directory_path_root = os.path.join(cache_dir, self.name.replace("/", "___")) + namespace_and_dataset_name = self.name.split("/") + namespace_and_dataset_name[-1] = camelcase_to_snakecase(namespace_and_dataset_name[-1]) + cached_relative_path = "___".join(namespace_and_dataset_name) + cached_datasets_directory_path_root = os.path.join(cache_dir, cached_relative_path) cached_directory_paths = [ cached_directory_path for cached_directory_path in glob.glob(os.path.join(cached_datasets_directory_path_root, "*", "*", "*")) diff --git a/src/datasets/packaged_modules/cache/cache.py b/src/datasets/packaged_modules/cache/cache.py index 2f31176e08b..9085b22078b 100644 --- a/src/datasets/packaged_modules/cache/cache.py +++ b/src/datasets/packaged_modules/cache/cache.py @@ -12,7 +12,7 @@ import datasets import datasets.config import datasets.data_files -from datasets.naming import filenames_for_dataset_split +from datasets.naming import camelcase_to_snakecase, filenames_for_dataset_split logger = datasets.utils.logging.get_logger(__name__) @@ -36,7 +36,10 @@ def _find_hash_in_cache( else: config_id = None cache_dir = os.path.expanduser(str(cache_dir or datasets.config.HF_DATASETS_CACHE)) - cached_datasets_directory_path_root = os.path.join(cache_dir, dataset_name.replace("/", "___")) + namespace_and_dataset_name = dataset_name.split("/") + namespace_and_dataset_name[-1] = camelcase_to_snakecase(namespace_and_dataset_name[-1]) + cached_relative_path = "___".join(namespace_and_dataset_name) + cached_datasets_directory_path_root = os.path.join(cache_dir, cached_relative_path) cached_directory_paths = [ cached_directory_path for cached_directory_path in glob.glob( diff --git a/tests/packaged_modules/test_cache.py b/tests/packaged_modules/test_cache.py index fdde27dbc1c..cfb947777e4 100644 --- a/tests/packaged_modules/test_cache.py +++ b/tests/packaged_modules/test_cache.py @@ -8,6 +8,7 @@ SAMPLE_DATASET_SINGLE_CONFIG_IN_METADATA = "hf-internal-testing/audiofolder_single_config_in_metadata" SAMPLE_DATASET_TWO_CONFIG_IN_METADATA = "hf-internal-testing/audiofolder_two_configs_in_metadata" +SAMPLE_DATASET_CAPITAL_LETTERS_IN_NAME = "hf-internal-testing/DatasetWithCapitalLetters" def test_cache(text_dir: Path, tmp_path: Path): @@ -133,3 +134,25 @@ def test_cache_single_config(tmp_path: Path): hash="auto", ) assert config_name in str(excinfo.value) + + +@pytest.mark.integration +def test_cache_capital_letters(tmp_path: Path): + cache_dir = tmp_path / "test_cache_capital_letters" + repo_id = SAMPLE_DATASET_CAPITAL_LETTERS_IN_NAME + dataset_name = repo_id.split("/")[-1] + ds = load_dataset(repo_id, cache_dir=str(cache_dir)) + cache = Cache(cache_dir=str(cache_dir), dataset_name=dataset_name, repo_id=repo_id, version="auto", hash="auto") + reloaded = cache.as_dataset() + assert list(ds) == list(reloaded) + assert len(ds["train"]) == len(reloaded["train"]) + cache = Cache( + cache_dir=str(cache_dir), + dataset_name=dataset_name, + repo_id=repo_id, + version="auto", + hash="auto", + ) + reloaded = cache.as_dataset() + assert list(ds) == list(reloaded) + assert len(ds["train"]) == len(reloaded["train"]) diff --git a/tests/test_load.py b/tests/test_load.py index c59ce7d5e6c..4e86efaf012 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -100,6 +100,7 @@ def _generate_examples(self, filepath, **kwargs): SAMPLE_DATASET_TWO_CONFIG_IN_METADATA_WITH_DEFAULT = ( "hf-internal-testing/audiofolder_two_configs_in_metadata_with_default" ) +SAMPLE_DATASET_CAPITAL_LETTERS_IN_NAME = "hf-internal-testing/DatasetWithCapitalLetters" METRIC_LOADING_SCRIPT_NAME = "__dummy_metric1__" @@ -1026,6 +1027,19 @@ def test_offline_dataset_module_factory_with_script(self): self.assertNotEqual(dataset_module_1.module_path, dataset_module_3.module_path) self.assertIn("Using the latest cached version of the module", self._caplog.text) + @pytest.mark.integration + def test_offline_dataset_module_factory_with_capital_letters_in_name(self): + repo_id = SAMPLE_DATASET_CAPITAL_LETTERS_IN_NAME + builder = load_dataset_builder(repo_id, cache_dir=self.cache_dir) + builder.download_and_prepare() + for offline_simulation_mode in list(OfflineSimulationMode): + with offline(offline_simulation_mode): + self._caplog.clear() + # allow provide the repo id without an explicit path to remote or local actual file + dataset_module = datasets.load.dataset_module_factory(repo_id, cache_dir=self.cache_dir) + self.assertEqual(dataset_module.module_path, "datasets.packaged_modules.cache.cache") + self.assertIn("Using the latest cached version of the dataset", self._caplog.text) + def test_load_dataset_from_hub(self): with self.assertRaises(DatasetNotFoundError) as context: datasets.load_dataset("_dummy")