Skip to content

Commit

Permalink
update tests
Browse files (browse the repository at this point in the history)
  • Loading branch information
lhoestq committed Oct 17, 2024
1 parent a3251f7 commit 7c5eb4b
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 17 deletions.
6 changes: 3 additions & 3 deletions src/datasets/load.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -980,7 +980,7 @@ class HubDatasetModuleFactoryWithoutScript(_DatasetModuleFactory):
def __init__(
self,
name: str,
commit_hash: Optional[str] = None,
commit_hash: str,
data_dir: Optional[str] = None,
data_files: Optional[Union[str, List, Dict]] = None,
download_config: Optional[DownloadConfig] = None,
Expand Down Expand Up @@ -1165,7 +1165,7 @@ class HubDatasetModuleFactoryWithParquetExport(_DatasetModuleFactory):
def __init__(
self,
name: str,
commit_hash: Optional[str] = None,
commit_hash: str,
download_config: Optional[DownloadConfig] = None,
):
self.name = name
Expand Down Expand Up @@ -1223,7 +1223,7 @@ class HubDatasetModuleFactoryWithScript(_DatasetModuleFactory):
def __init__(
self,
name: str,
commit_hash: Optional[str] = None,
commit_hash: str,
download_config: Optional[DownloadConfig] = None,
download_mode: Optional[Union[DownloadMode, str]] = None,
dynamic_modules_path: Optional[str] = None,
Expand Down
58 changes: 44 additions & 14 deletions tests/test_load.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -90,8 +90,13 @@ def _generate_examples(self, filepath, **kwargs):
SAMPLE_DATASET_IDENTIFIER3 = "hf-internal-testing/multi_dir_dataset" # has multiple data directories
SAMPLE_DATASET_IDENTIFIER4 = "hf-internal-testing/imagefolder_with_metadata" # imagefolder with a metadata file outside of the train/test directories
SAMPLE_DATASET_IDENTIFIER5 = "hf-internal-testing/imagefolder_with_metadata_no_splits" # imagefolder with a metadata file and no default split names in data files
SAMPLE_NOT_EXISTING_DATASET_IDENTIFIER = "hf-internal-testing/_dummy"
SAMPLE_DATASET_NAME_THAT_DOESNT_EXIST = "_dummy"

SAMPLE_DATASET_COMMIT_HASH = "0e1cee81e718feadf49560b287c4eb669c2efb1a"
SAMPLE_DATASET_COMMIT_HASH2 = "c19550d35263090b1ec2bfefdbd737431fafec40"
SAMPLE_DATASET_COMMIT_HASH3 = "aaa2d4bdd1d877d1c6178562cfc584bdfa90f6dc"
SAMPLE_DATASET_COMMIT_HASH4 = "a7415617490f32e51c2f0ea20b5ce7cfba035a62"
SAMPLE_DATASET_COMMIT_HASH5 = "4971fa562942cab8263f56a448c3f831b18f1c27"

SAMPLE_DATASET_NO_CONFIGS_IN_METADATA = "hf-internal-testing/audiofolder_no_configs_in_metadata"
SAMPLE_DATASET_SINGLE_CONFIG_IN_METADATA = "hf-internal-testing/audiofolder_single_config_in_metadata"
SAMPLE_DATASET_TWO_CONFIG_IN_METADATA = "hf-internal-testing/audiofolder_two_configs_in_metadata"
Expand All @@ -100,6 +105,15 @@ def _generate_examples(self, filepath, **kwargs):
)
SAMPLE_DATASET_CAPITAL_LETTERS_IN_NAME = "hf-internal-testing/DatasetWithCapitalLetters"

SAMPLE_DATASET_NO_CONFIGS_IN_METADATA_COMMIT_HASH = "26cd5079bb0d3cd1521c6894765a0b8edb159d7f"
SAMPLE_DATASET_SINGLE_CONFIG_IN_METADATA_COMMIT_HASH = "1668dfc91efae975e44457cdabef60fb9200820a"
SAMPLE_DATASET_TWO_CONFIG_IN_METADATA_COMMIT_HASH = "e71bce498e6c2bd2c58b20b097fdd3389793263f"
SAMPLE_DATASET_TWO_CONFIG_IN_METADATA_WITH_DEFAULT_COMMIT_HASH = "38937109bb4dc7067f575fe6e7b420158eb9cf32"
SAMPLE_DATASET_CAPITAL_LETTERS_IN_NAME_COMMIT_HASH = "70aa36264a6954920a13dd0465156a60b9f8af4b"

SAMPLE_NOT_EXISTING_DATASET_IDENTIFIER = "hf-internal-testing/_dummy"
SAMPLE_DATASET_NAME_THAT_DOESNT_EXIST = "_dummy"


@pytest.fixture
def data_dir(tmp_path):
Expand Down Expand Up @@ -388,14 +402,16 @@ def setUp(self):

def test_HubDatasetModuleFactoryWithScript_dont_trust_remote_code(self):
factory = HubDatasetModuleFactoryWithScript(
"hf-internal-testing/dataset_with_script",
SAMPLE_DATASET_IDENTIFIER,
commit_hash=SAMPLE_DATASET_COMMIT_HASH,
download_config=self.download_config,
dynamic_modules_path=self.dynamic_modules_path,
)
with patch.object(config, "HF_DATASETS_TRUST_REMOTE_CODE", None): # this will be the default soon
self.assertRaises(ValueError, factory.get_module)
factory = HubDatasetModuleFactoryWithScript(
"hf-internal-testing/dataset_with_script",
SAMPLE_DATASET_IDENTIFIER,
commit_hash=SAMPLE_DATASET_COMMIT_HASH,
download_config=self.download_config,
dynamic_modules_path=self.dynamic_modules_path,
trust_remote_code=False,
Expand All @@ -406,9 +422,9 @@ def test_HubDatasetModuleFactoryWithScript_with_hub_dataset(self):
# "wmt_t2t" has additional imports (internal)
factory = HubDatasetModuleFactoryWithScript(
"wmt_t2t",
commit_hash="861aac88b2c6247dd93ade8b1c189ce714627750",
download_config=self.download_config,
dynamic_modules_path=self.dynamic_modules_path,
revision="861aac88b2c6247dd93ade8b1c189ce714627750",
trust_remote_code=True,
)
module_factory_result = factory.get_module()
Expand Down Expand Up @@ -616,7 +632,7 @@ def test_PackagedDatasetModuleFactory_with_data_dir_and_metadata(self):
@pytest.mark.integration
def test_HubDatasetModuleFactoryWithoutScript(self):
factory = HubDatasetModuleFactoryWithoutScript(
SAMPLE_DATASET_IDENTIFIER2, download_config=self.download_config
SAMPLE_DATASET_IDENTIFIER2, commit_hash=SAMPLE_DATASET_COMMIT_HASH2, download_config=self.download_config
)
module_factory_result = factory.get_module()
assert importlib.import_module(module_factory_result.module_path) is not None
Expand All @@ -626,7 +642,10 @@ def test_HubDatasetModuleFactoryWithoutScript(self):
def test_HubDatasetModuleFactoryWithoutScript_with_data_dir(self):
data_dir = "data2"
factory = HubDatasetModuleFactoryWithoutScript(
SAMPLE_DATASET_IDENTIFIER3, data_dir=data_dir, download_config=self.download_config
SAMPLE_DATASET_IDENTIFIER3,
commit_hash=SAMPLE_DATASET_COMMIT_HASH3,
data_dir=data_dir,
download_config=self.download_config,
)
module_factory_result = factory.get_module()
assert importlib.import_module(module_factory_result.module_path) is not None
Expand All @@ -645,7 +664,7 @@ def test_HubDatasetModuleFactoryWithoutScript_with_data_dir(self):
@pytest.mark.integration
def test_HubDatasetModuleFactoryWithoutScript_with_metadata(self):
factory = HubDatasetModuleFactoryWithoutScript(
SAMPLE_DATASET_IDENTIFIER4, download_config=self.download_config
SAMPLE_DATASET_IDENTIFIER4, commit_hash=SAMPLE_DATASET_COMMIT_HASH4, download_config=self.download_config
)
module_factory_result = factory.get_module()
assert importlib.import_module(module_factory_result.module_path) is not None
Expand All @@ -660,7 +679,7 @@ def test_HubDatasetModuleFactoryWithoutScript_with_metadata(self):
assert any(Path(data_file).name == "metadata.jsonl" for data_file in builder_config.data_files["test"])

factory = HubDatasetModuleFactoryWithoutScript(
SAMPLE_DATASET_IDENTIFIER5, download_config=self.download_config
SAMPLE_DATASET_IDENTIFIER5, commit_hash=SAMPLE_DATASET_COMMIT_HASH5, download_config=self.download_config
)
module_factory_result = factory.get_module()
assert importlib.import_module(module_factory_result.module_path) is not None
Expand All @@ -677,6 +696,7 @@ def test_HubDatasetModuleFactoryWithoutScript_with_metadata(self):
def test_HubDatasetModuleFactoryWithoutScript_with_one_default_config_in_metadata(self):
factory = HubDatasetModuleFactoryWithoutScript(
SAMPLE_DATASET_SINGLE_CONFIG_IN_METADATA,
commit_hash=SAMPLE_DATASET_SINGLE_CONFIG_IN_METADATA_COMMIT_HASH,
download_config=self.download_config,
)
module_factory_result = factory.get_module()
Expand Down Expand Up @@ -714,9 +734,17 @@ def test_HubDatasetModuleFactoryWithoutScript_with_one_default_config_in_metadat

@pytest.mark.integration
def test_HubDatasetModuleFactoryWithoutScript_with_two_configs_in_metadata(self):
datasets_names = [SAMPLE_DATASET_TWO_CONFIG_IN_METADATA, SAMPLE_DATASET_TWO_CONFIG_IN_METADATA_WITH_DEFAULT]
for dataset_name in datasets_names:
factory = HubDatasetModuleFactoryWithoutScript(dataset_name, download_config=self.download_config)
datasets_names = [
(SAMPLE_DATASET_TWO_CONFIG_IN_METADATA, SAMPLE_DATASET_TWO_CONFIG_IN_METADATA_COMMIT_HASH),
(
SAMPLE_DATASET_TWO_CONFIG_IN_METADATA_WITH_DEFAULT,
SAMPLE_DATASET_TWO_CONFIG_IN_METADATA_WITH_DEFAULT_COMMIT_HASH,
),
]
for dataset_name, commit_hash in datasets_names:
factory = HubDatasetModuleFactoryWithoutScript(
dataset_name, commit_hash=commit_hash, download_config=self.download_config
)
module_factory_result = factory.get_module()
assert importlib.import_module(module_factory_result.module_path) is not None

Expand Down Expand Up @@ -767,6 +795,7 @@ def test_HubDatasetModuleFactoryWithoutScript_with_two_configs_in_metadata(self)
def test_HubDatasetModuleFactoryWithScript(self):
factory = HubDatasetModuleFactoryWithScript(
SAMPLE_DATASET_IDENTIFIER,
commit_hash=SAMPLE_DATASET_COMMIT_HASH,
download_config=self.download_config,
dynamic_modules_path=self.dynamic_modules_path,
trust_remote_code=True,
Expand All @@ -779,6 +808,7 @@ def test_HubDatasetModuleFactoryWithScript(self):
def test_HubDatasetModuleFactoryWithParquetExport(self):
factory = HubDatasetModuleFactoryWithParquetExport(
SAMPLE_DATASET_IDENTIFIER,
commit_hash=SAMPLE_DATASET_COMMIT_HASH,
download_config=self.download_config,
)
module_factory_result = factory.get_module()
Expand All @@ -802,13 +832,13 @@ def test_HubDatasetModuleFactoryWithParquetExport_errors_on_wrong_sha(self):
factory = HubDatasetModuleFactoryWithParquetExport(
SAMPLE_DATASET_IDENTIFIER,
download_config=self.download_config,
revision="0e1cee81e718feadf49560b287c4eb669c2efb1a",
commit_hash=SAMPLE_DATASET_COMMIT_HASH,
)
factory.get_module()
factory = HubDatasetModuleFactoryWithParquetExport(
SAMPLE_DATASET_IDENTIFIER,
download_config=self.download_config,
revision="wrong_sha",
commit_hash="wrong_sha",
)
with self.assertRaises(_dataset_viewer.DatasetViewerError):
factory.get_module()
Expand Down

0 comments on commit 7c5eb4b

Please sign in to comment.