
Commit

fix tests
lhoestq committed Oct 17, 2024
1 parent 7c5eb4b commit c97703c
Showing 4 changed files with 133 additions and 103 deletions.
6 changes: 5 additions & 1 deletion src/datasets/download/download_manager.py
@@ -189,7 +189,11 @@ def _download_batched(
         download_func = partial(self._download_single, download_config=download_config)

         fs: fsspec.AbstractFileSystem
-        fs, path = url_to_fs(url_or_filenames[0], **download_config.storage_options)
+        path = str(url_or_filenames[0])
+        if is_relative_path(path):
+            # append the relative path to the base_path
+            path = url_or_path_join(self._base_path, path)
+        fs, path = url_to_fs(path, **download_config.storage_options)
         size = 0
         try:
             size = fs.info(path).get("size", 0)
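
For orientation, a minimal standalone sketch of the patched logic; the helper name and the crude relative-path check are hypothetical stand-ins for datasets' internals:

from fsspec.core import url_to_fs

def stat_download_size(url_or_filename, base_path, storage_options):
    # Mirror of the hunk above: a relative path is first joined onto base_path,
    # then the result is handed to fsspec to resolve a filesystem + path.
    path = str(url_or_filename)
    if "://" not in path and not path.startswith("/"):  # stand-in for is_relative_path
        path = base_path.rstrip("/") + "/" + path
    fs, path = url_to_fs(path, **storage_options)
    try:
        return fs.info(path).get("size", 0)  # 0 when the size is unknown
    except OSError:
        return 0

# e.g. stat_download_size("data/train.csv", "https://example.com/repo", {})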
205 changes: 114 additions & 91 deletions src/datasets/load.py
@@ -1186,8 +1186,26 @@ def get_module(self) -> DatasetModule:
                 for config_name in exported_dataset_infos
             }
         )
+        parquet_commit_hash = (
+            HfApi(
+                endpoint=config.HF_ENDPOINT,
+                token=self.download_config.token,
+                library_name="datasets",
+                library_version=__version__,
+                user_agent=get_datasets_user_agent(self.download_config.user_agent),
+            )
+            .dataset_info(
+                self.name,
+                revision="refs/convert/parquet",
+                token=self.download_config.token,
+                timeout=100.0,
+            )
+            .sha
+        )  # fix the revision in case there are new commits in the meantime
         metadata_configs = MetadataConfigs._from_exported_parquet_files_and_dataset_infos(
-            commit_hash=self.commit_hash, exported_parquet_files=exported_parquet_files, dataset_infos=dataset_infos
+            parquet_commit_hash=parquet_commit_hash,
+            exported_parquet_files=exported_parquet_files,
+            dataset_infos=dataset_infos,
         )
         module_path, _ = _PACKAGED_DATASETS_MODULES["parquet"]
         builder_configs, default_config_name = create_builder_configs_from_metadata_configs(
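
The new parquet_commit_hash pins the moving refs/convert/parquet ref to a fixed commit. A minimal sketch of the same call via huggingface_hub, with a hypothetical example dataset:

from huggingface_hub import HfApi

# Resolve the moving "refs/convert/parquet" ref to an immutable commit sha once,
# so the parquet URLs built from it keep pointing at the same files even if the
# export is regenerated in the meantime.
parquet_commit_hash = HfApi().dataset_info(
    "rajpurkar/squad",  # hypothetical example dataset
    revision="refs/convert/parquet",
    timeout=100.0,
).sha
print(parquet_commit_hash)  # a 40-character hex sha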
@@ -1335,7 +1353,7 @@ def get_module(self) -> DatasetModule:
         # make the new module to be noticed by the import system
         importlib.invalidate_caches()
         builder_kwargs = {
-            "base_path": hf_dataset_url(self.name, "", revision=self.revision).rstrip("/"),
+            "base_path": hf_dataset_url(self.name, "", revision=self.commit_hash).rstrip("/"),
             "repo_id": self.name,
         }
         return DatasetModule(module_path, hash, builder_kwargs, importable_file_path=importable_file_path)
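
Switching base_path from self.revision to self.commit_hash makes file downloads immune to new commits landing between module resolution and download. A sketch of the resulting URL, assuming the Hub's usual resolve-URL layout (all values made up):

endpoint = "https://huggingface.co"
repo_id = "user/my_dataset"                            # hypothetical
commit_hash = "7c5eb4b0123456789abcdef0123456789abcdef"  # resolved earlier
# A branch ref like "main" can move; a commit sha cannot.
base_path = f"{endpoint}/datasets/{repo_id}/resolve/{commit_hash}".rstrip("/")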
@@ -1565,97 +1583,102 @@ def dataset_module_factory(
         ).get_module()
     # Try remotely
     elif is_relative_path(path) and path.count("/") <= 1:
-        # Get the Dataset Card + get the revision + check authentication all at in one call
-        # We fix the commit_hash in case there are new commits in the meantime
-        api = HfApi(
-            endpoint=config.HF_ENDPOINT,
-            token=download_config.token,
-            library_name="datasets",
-            library_version=__version__,
-            user_agent=get_datasets_user_agent(download_config.user_agent),
-        )
         try:
-            _raise_if_offline_mode_is_enabled()
-            dataset_readme_path = api.hf_hub_download(
-                repo_id=path,
-                filename=config.REPOCARD_FILENAME,
-                repo_type="dataset",
-                revision=revision,
-                proxies=download_config.proxies,
+            # Get the Dataset Card + get the revision + check authentication all at in one call
+            # We fix the commit_hash in case there are new commits in the meantime
+            api = HfApi(
+                endpoint=config.HF_ENDPOINT,
+                token=download_config.token,
+                library_name="datasets",
+                library_version=__version__,
+                user_agent=get_datasets_user_agent(download_config.user_agent),
             )
-            commit_hash = os.path.basename(os.path.dirname(dataset_readme_path))
-        except EntryNotFoundError:
-            commit_hash = api.dataset_info(
-                path,
-                revision=revision,
-                timeout=100.0,
-            ).sha
-        except (
-            OfflineModeIsEnabled,
-            requests.exceptions.ConnectTimeout,
-            requests.exceptions.ConnectionError,
-        ) as e:
-            raise ConnectionError(f"Couldn't reach '{path}' on the Hub ({e.__class__.__name__})") from e
-        except GatedRepoError as e:
-            message = f"Dataset '{path}' is a gated dataset on the Hub."
-            if "401 Client Error" in str(e):
-                message += " You must be authenticated to access it."
-            elif "403 Client Error" in str(e):
-                message += f" Visit the dataset page at https://huggingface.co/datasets/{path} to ask for access."
-            raise DatasetNotFoundError(message) from e
-        except RevisionNotFoundError as e:
-            raise DatasetNotFoundError(f"Revision '{revision}' doesn't exist for dataset '{path}' on the Hub.") from e
-        except RepositoryNotFoundError as e:
-            raise DatasetNotFoundError(f"Dataset '{path}' doesn't exist on the Hub or cannot be accessed.") from e
-        try:
-            dataset_script_path = api.hf_hub_download(
-                repo_id=path,
-                filename=filename,
-                repo_type="dataset",
-                revision=commit_hash,
-                proxies=download_config.proxies,
-            )
-            if _require_custom_configs or (revision and revision != "main"):
-                can_load_config_from_parquet_export = False
-            elif _require_default_config_name:
-                with open(dataset_script_path, "r", encoding="utf-8") as f:
-                    can_load_config_from_parquet_export = "DEFAULT_CONFIG_NAME" not in f.read()
-            else:
-                can_load_config_from_parquet_export = True
-            if config.USE_PARQUET_EXPORT and can_load_config_from_parquet_export:
-                # If the parquet export is ready (parquet files + info available for the current sha), we can use it instead
-                # This fails when the dataset has multiple configs and a default config and
-                # the user didn't specify a configuration name (_require_default_config_name=True).
-                try:
-                    return HubDatasetModuleFactoryWithParquetExport(
-                        path, download_config=download_config, commit_hash=commit_hash
-                    ).get_module()
-                except _dataset_viewer.DatasetViewerError:
-                    pass
-            # Otherwise we must use the dataset script if the user trusts it
-            return HubDatasetModuleFactoryWithScript(
-                path,
-                commit_hash=commit_hash,
-                download_config=download_config,
-                download_mode=download_mode,
-                dynamic_modules_path=dynamic_modules_path,
-                trust_remote_code=trust_remote_code,
-            ).get_module()
-        except EntryNotFoundError:
-            # Use the infos from the parquet export except in some cases:
-            if data_dir or data_files or (revision and revision != "main"):
-                use_exported_dataset_infos = False
-            else:
-                use_exported_dataset_infos = True
-            return HubDatasetModuleFactoryWithoutScript(
-                path,
-                commit_hash=commit_hash,
-                data_dir=data_dir,
-                data_files=data_files,
-                download_config=download_config,
-                download_mode=download_mode,
-                use_exported_dataset_infos=use_exported_dataset_infos,
-            ).get_module()
+            try:
+                _raise_if_offline_mode_is_enabled()
+                dataset_readme_path = api.hf_hub_download(
+                    repo_id=path,
+                    filename=config.REPOCARD_FILENAME,
+                    repo_type="dataset",
+                    revision=revision,
+                    proxies=download_config.proxies,
+                )
+                commit_hash = os.path.basename(os.path.dirname(dataset_readme_path))
+            except EntryNotFoundError:
+                commit_hash = api.dataset_info(
+                    path,
+                    revision=revision,
+                    timeout=100.0,
+                ).sha
+            except (
+                OfflineModeIsEnabled,
+                requests.exceptions.ConnectTimeout,
+                requests.exceptions.ConnectionError,
+            ) as e:
+                raise ConnectionError(f"Couldn't reach '{path}' on the Hub ({e.__class__.__name__})") from e
+            except GatedRepoError as e:
+                message = f"Dataset '{path}' is a gated dataset on the Hub."
+                if "401 Client Error" in str(e):
+                    message += " You must be authenticated to access it."
+                elif "403 Client Error" in str(e):
+                    message += f" Visit the dataset page at https://huggingface.co/datasets/{path} to ask for access."
+                raise DatasetNotFoundError(message) from e
+            except RevisionNotFoundError as e:
+                raise DatasetNotFoundError(
+                    f"Revision '{revision}' doesn't exist for dataset '{path}' on the Hub."
+                ) from e
+            except RepositoryNotFoundError as e:
+                raise DatasetNotFoundError(f"Dataset '{path}' doesn't exist on the Hub or cannot be accessed.") from e
+            try:
+                dataset_script_path = api.hf_hub_download(
+                    repo_id=path,
+                    filename=filename,
+                    repo_type="dataset",
+                    revision=commit_hash,
+                    proxies=download_config.proxies,
+                )
+                if _require_custom_configs or (revision and revision != "main"):
+                    can_load_config_from_parquet_export = False
+                elif _require_default_config_name:
+                    with open(dataset_script_path, "r", encoding="utf-8") as f:
+                        can_load_config_from_parquet_export = "DEFAULT_CONFIG_NAME" not in f.read()
+                else:
+                    can_load_config_from_parquet_export = True
+                if config.USE_PARQUET_EXPORT and can_load_config_from_parquet_export:
+                    # If the parquet export is ready (parquet files + info available for the current sha), we can use it instead
+                    # This fails when the dataset has multiple configs and a default config and
+                    # the user didn't specify a configuration name (_require_default_config_name=True).
+                    try:
+                        out = HubDatasetModuleFactoryWithParquetExport(
+                            path, download_config=download_config, commit_hash=commit_hash
+                        ).get_module()
+                        logger.info("Loading the dataset from the Parquet export on Hugging Face.")
+                        return out
+                    except _dataset_viewer.DatasetViewerError:
+                        pass
+                # Otherwise we must use the dataset script if the user trusts it
+                return HubDatasetModuleFactoryWithScript(
+                    path,
+                    commit_hash=commit_hash,
+                    download_config=download_config,
+                    download_mode=download_mode,
+                    dynamic_modules_path=dynamic_modules_path,
+                    trust_remote_code=trust_remote_code,
+                ).get_module()
+            except EntryNotFoundError:
+                # Use the infos from the parquet export except in some cases:
+                if data_dir or data_files or (revision and revision != "main"):
+                    use_exported_dataset_infos = False
+                else:
+                    use_exported_dataset_infos = True
+                return HubDatasetModuleFactoryWithoutScript(
+                    path,
+                    commit_hash=commit_hash,
+                    data_dir=data_dir,
+                    data_files=data_files,
+                    download_config=download_config,
+                    download_mode=download_mode,
+                    use_exported_dataset_infos=use_exported_dataset_infos,
+                ).get_module()
         except Exception as e1:
             # All the attempts failed, before raising the error we should check if the module is already cached
             try:
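
The net effect of the large hunk above: the whole Hub lookup, not just the script/parquet choice, now sits inside one try whose except Exception as e1 falls back to the cache. A condensed, hypothetical sketch of that control flow (stub names, not the real factories):

class DatasetViewerError(Exception):
    # Stand-in for datasets' _dataset_viewer.DatasetViewerError
    pass

def resolve_hub_module(path, commit_hash, can_use_parquet_export, load_parquet, load_script, load_cache):
    # Order of preference: parquet export, then dataset script, then (new in
    # this commit for the whole block) the locally cached module on any failure.
    try:
        if can_use_parquet_export:
            try:
                return load_parquet(path, commit_hash)
            except DatasetViewerError:
                pass  # export not ready for this sha -> fall through
        return load_script(path, commit_hash)
    except Exception:
        return load_cache(path)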
4 changes: 2 additions & 2 deletions src/datasets/utils/metadata.py
@@ -102,7 +102,7 @@ def _raise_if_data_files_field_not_valid(metadata_config: dict):
     @classmethod
     def _from_exported_parquet_files_and_dataset_infos(
         cls,
-        commit_hash: str,
+        parquet_commit_hash: str,
         exported_parquet_files: List[Dict[str, Any]],
         dataset_infos: DatasetInfosDict,
     ) -> "MetadataConfigs":
@@ -112,7 +112,7 @@ def _from_exported_parquet_files_and_dataset_infos(
                         {
                             "split": split_name,
                             "path": [
-                                parquet_file["url"].replace("refs%2Fconvert%2Fparquet", commit_hash)
+                                parquet_file["url"].replace("refs%2Fconvert%2Fparquet", parquet_commit_hash)
                                 for parquet_file in parquet_files_for_split
                             ],
                         }
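
A worked example of the rename's effect on an exported parquet URL (the URL is made up for illustration):

parquet_commit_hash = "abcdef0123456789abcdef0123456789abcdef01"  # pinned sha
parquet_file = {
    "url": "https://huggingface.co/datasets/user/ds/resolve/refs%2Fconvert%2Fparquet/default/train/0000.parquet"
}
pinned_url = parquet_file["url"].replace("refs%2Fconvert%2Fparquet", parquet_commit_hash)
print(pinned_url)  # .../resolve/abcdef0123456789abcdef0123456789abcdef01/default/train/0000.parquet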
21 changes: 12 additions & 9 deletions tests/test_load.py
@@ -876,19 +876,22 @@ def test_CachedDatasetModuleFactory_with_script(self):
 
 
 @pytest.mark.parametrize(
-    "factory_class",
+    "factory_class,requires_commit_hash",
     [
-        CachedDatasetModuleFactory,
-        HubDatasetModuleFactoryWithoutScript,
-        HubDatasetModuleFactoryWithScript,
-        LocalDatasetModuleFactoryWithoutScript,
-        LocalDatasetModuleFactoryWithScript,
-        PackagedDatasetModuleFactory,
+        (CachedDatasetModuleFactory, False),
+        (HubDatasetModuleFactoryWithoutScript, True),
+        (HubDatasetModuleFactoryWithScript, True),
+        (LocalDatasetModuleFactoryWithoutScript, False),
+        (LocalDatasetModuleFactoryWithScript, False),
+        (PackagedDatasetModuleFactory, False),
     ],
 )
-def test_module_factories(factory_class):
+def test_module_factories(factory_class, requires_commit_hash):
     name = "dummy_name"
-    factory = factory_class(name)
+    if requires_commit_hash:
+        factory = factory_class(name, commit_hash="foo")
+    else:
+        factory = factory_class(name)
     assert factory.name == name
 
 
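
For readers less familiar with tuple-valued parametrize, a standalone toy example of the pattern the updated test uses: pytest unpacks each tuple into the named arguments, one test case per entry.

import pytest

@pytest.mark.parametrize(
    "value,expected",
    [
        (2, True),
        (3, False),
    ],
)
def test_is_even(value, expected):
    assert (value % 2 == 0) is expected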
