Skip to content

Commit

Permalink
Fix cache path to snakecase for CachedDatasetModuleFactory and `Cache` (#6754)

Browse files Browse the repository at this point in the history

* fix cache path snakecase

* Apply suggestions from code review

Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>

* fix variable names of suggestions

* add a capital-letters dataset test to `test_load.py` and `test_cache.py`

---------

Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>
  • Loading branch information
izhx and lhoestq authored Apr 15, 2024
1 parent a3bc89d commit 91b07b9
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 3 deletions.
5 changes: 4 additions & 1 deletion src/datasets/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -1608,7 +1608,10 @@ def _get_modification_time(module_hash):
}
return DatasetModule(module_path, hash, builder_kwargs, importable_file_path=importable_file_path)
cache_dir = os.path.expanduser(str(self.cache_dir or config.HF_DATASETS_CACHE))
cached_datasets_directory_path_root = os.path.join(cache_dir, self.name.replace("/", "___"))
namespace_and_dataset_name = self.name.split("/")
namespace_and_dataset_name[-1] = camelcase_to_snakecase(namespace_and_dataset_name[-1])
cached_relative_path = "___".join(namespace_and_dataset_name)
cached_datasets_directory_path_root = os.path.join(cache_dir, cached_relative_path)
cached_directory_paths = [
cached_directory_path
for cached_directory_path in glob.glob(os.path.join(cached_datasets_directory_path_root, "*", "*", "*"))
Expand Down
7 changes: 5 additions & 2 deletions src/datasets/packaged_modules/cache/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import datasets
import datasets.config
import datasets.data_files
from datasets.naming import filenames_for_dataset_split
from datasets.naming import camelcase_to_snakecase, filenames_for_dataset_split


logger = datasets.utils.logging.get_logger(__name__)
Expand All @@ -36,7 +36,10 @@ def _find_hash_in_cache(
else:
config_id = None
cache_dir = os.path.expanduser(str(cache_dir or datasets.config.HF_DATASETS_CACHE))
cached_datasets_directory_path_root = os.path.join(cache_dir, dataset_name.replace("/", "___"))
namespace_and_dataset_name = dataset_name.split("/")
namespace_and_dataset_name[-1] = camelcase_to_snakecase(namespace_and_dataset_name[-1])
cached_relative_path = "___".join(namespace_and_dataset_name)
cached_datasets_directory_path_root = os.path.join(cache_dir, cached_relative_path)
cached_directory_paths = [
cached_directory_path
for cached_directory_path in glob.glob(
Expand Down
23 changes: 23 additions & 0 deletions tests/packaged_modules/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

SAMPLE_DATASET_SINGLE_CONFIG_IN_METADATA = "hf-internal-testing/audiofolder_single_config_in_metadata"
SAMPLE_DATASET_TWO_CONFIG_IN_METADATA = "hf-internal-testing/audiofolder_two_configs_in_metadata"
SAMPLE_DATASET_CAPITAL_LETTERS_IN_NAME = "hf-internal-testing/DatasetWithCapitalLetters"


def test_cache(text_dir: Path, tmp_path: Path):
Expand Down Expand Up @@ -133,3 +134,25 @@ def test_cache_single_config(tmp_path: Path):
hash="auto",
)
assert config_name in str(excinfo.value)


@pytest.mark.integration
def test_cache_capital_letters(tmp_path: Path):
    """Cache resolution must work for repo names containing capital letters.

    The on-disk cache directory is snake_cased, so ``Cache`` has to map the
    CamelCase dataset name onto that directory. Resolving twice checks that
    the lookup is stable across repeated instantiations.
    """
    storage_dir = str(tmp_path / "test_cache_capital_letters")
    repo_id = SAMPLE_DATASET_CAPITAL_LETTERS_IN_NAME
    name = repo_id.split("/")[-1]
    original = load_dataset(repo_id, cache_dir=storage_dir)
    for _ in range(2):
        # Re-resolve the cached dataset from scratch each time.
        from_cache = Cache(
            cache_dir=storage_dir,
            dataset_name=name,
            repo_id=repo_id,
            version="auto",
            hash="auto",
        ).as_dataset()
        assert list(original) == list(from_cache)
        assert len(original["train"]) == len(from_cache["train"])
14 changes: 14 additions & 0 deletions tests/test_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def _generate_examples(self, filepath, **kwargs):
SAMPLE_DATASET_TWO_CONFIG_IN_METADATA_WITH_DEFAULT = (
"hf-internal-testing/audiofolder_two_configs_in_metadata_with_default"
)
SAMPLE_DATASET_CAPITAL_LETTERS_IN_NAME = "hf-internal-testing/DatasetWithCapitalLetters"


METRIC_LOADING_SCRIPT_NAME = "__dummy_metric1__"
Expand Down Expand Up @@ -1026,6 +1027,19 @@ def test_offline_dataset_module_factory_with_script(self):
self.assertNotEqual(dataset_module_1.module_path, dataset_module_3.module_path)
self.assertIn("Using the latest cached version of the module", self._caplog.text)

@pytest.mark.integration
def test_offline_dataset_module_factory_with_capital_letters_in_name(self):
repo_id = SAMPLE_DATASET_CAPITAL_LETTERS_IN_NAME
builder = load_dataset_builder(repo_id, cache_dir=self.cache_dir)
builder.download_and_prepare()
for offline_simulation_mode in list(OfflineSimulationMode):
with offline(offline_simulation_mode):
self._caplog.clear()
# allow provide the repo id without an explicit path to remote or local actual file
dataset_module = datasets.load.dataset_module_factory(repo_id, cache_dir=self.cache_dir)
self.assertEqual(dataset_module.module_path, "datasets.packaged_modules.cache.cache")
self.assertIn("Using the latest cached version of the dataset", self._caplog.text)

def test_load_dataset_from_hub(self):
with self.assertRaises(DatasetNotFoundError) as context:
datasets.load_dataset("_dummy")
Expand Down

0 comments on commit 91b07b9

Please sign in to comment.