diff --git a/src/datasets/download/download_manager.py b/src/datasets/download/download_manager.py
index 6ccb4f9d1c9..a8dff37aeef 100644
--- a/src/datasets/download/download_manager.py
+++ b/src/datasets/download/download_manager.py
@@ -189,7 +189,11 @@ def _download_batched(
             download_func = partial(self._download_single, download_config=download_config)
 
             fs: fsspec.AbstractFileSystem
-            fs, path = url_to_fs(url_or_filenames[0], **download_config.storage_options)
+            path = str(url_or_filenames[0])
+            if is_relative_path(path):
+                # append the relative path to the base_path
+                path = url_or_path_join(self._base_path, path)
+            fs, path = url_to_fs(path, **download_config.storage_options)
             size = 0
             try:
                 size = fs.info(path).get("size", 0)
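The new branch above resolves relative entries against the download manager's base path before asking fsspec for a size estimate. Below is a minimal sketch of that logic outside the library, using posixpath in place of datasets' internal is_relative_path / url_or_path_join helpers; the function name and example arguments are illustrative only.

```python
# Sketch of the fixed size probe: a relative entry is joined onto a base path
# before fsspec is queried, instead of being passed through unchanged.
import posixpath
from typing import Optional

from fsspec.core import url_to_fs


def probe_size(url_or_filename: str, base_path: str, storage_options: Optional[dict] = None) -> int:
    path = str(url_or_filename)
    if "://" not in path and not posixpath.isabs(path):
        # relative entry: resolve it against the base path first
        path = posixpath.join(base_path, path)
    fs, path = url_to_fs(path, **(storage_options or {}))
    try:
        return fs.info(path).get("size", 0)
    except Exception:
        return 0


# e.g. probe_size("data/train.csv", "/tmp/my_dataset")
```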
diff --git a/src/datasets/load.py b/src/datasets/load.py
index 795ee655adf..19811c8733f 100644
--- a/src/datasets/load.py
+++ b/src/datasets/load.py
@@ -1186,8 +1186,26 @@ def get_module(self) -> DatasetModule:
                 for config_name in exported_dataset_infos
             }
         )
+        parquet_commit_hash = (
+            HfApi(
+                endpoint=config.HF_ENDPOINT,
+                token=self.download_config.token,
+                library_name="datasets",
+                library_version=__version__,
+                user_agent=get_datasets_user_agent(self.download_config.user_agent),
+            )
+            .dataset_info(
+                self.name,
+                revision="refs/convert/parquet",
+                token=self.download_config.token,
+                timeout=100.0,
+            )
+            .sha
+        )  # fix the revision in case there are new commits in the meantime
         metadata_configs = MetadataConfigs._from_exported_parquet_files_and_dataset_infos(
-            commit_hash=self.commit_hash, exported_parquet_files=exported_parquet_files, dataset_infos=dataset_infos
+            parquet_commit_hash=parquet_commit_hash,
+            exported_parquet_files=exported_parquet_files,
+            dataset_infos=dataset_infos,
         )
         module_path, _ = _PACKAGED_DATASETS_MODULES["parquet"]
         builder_configs, default_config_name = create_builder_configs_from_metadata_configs(
@@ -1335,7 +1353,7 @@ def get_module(self) -> DatasetModule:
         # make the new module to be noticed by the import system
         importlib.invalidate_caches()
         builder_kwargs = {
-            "base_path": hf_dataset_url(self.name, "", revision=self.revision).rstrip("/"),
+            "base_path": hf_dataset_url(self.name, "", revision=self.commit_hash).rstrip("/"),
             "repo_id": self.name,
         }
         return DatasetModule(module_path, hash, builder_kwargs, importable_file_path=importable_file_path)
@@ -1565,97 +1583,102 @@ def dataset_module_factory(
         ).get_module()
     # Try remotely
     elif is_relative_path(path) and path.count("/") <= 1:
-        # Get the Dataset Card + get the revision + check authentication all at in one call
-        # We fix the commit_hash in case there are new commits in the meantime
-        api = HfApi(
-            endpoint=config.HF_ENDPOINT,
-            token=download_config.token,
-            library_name="datasets",
-            library_version=__version__,
-            user_agent=get_datasets_user_agent(download_config.user_agent),
-        )
         try:
-            _raise_if_offline_mode_is_enabled()
-            dataset_readme_path = api.hf_hub_download(
-                repo_id=path,
-                filename=config.REPOCARD_FILENAME,
-                repo_type="dataset",
-                revision=revision,
-                proxies=download_config.proxies,
+            # Get the Dataset Card + get the revision + check authentication all at in one call
+            # We fix the commit_hash in case there are new commits in the meantime
+            api = HfApi(
+                endpoint=config.HF_ENDPOINT,
+                token=download_config.token,
+                library_name="datasets",
+                library_version=__version__,
+                user_agent=get_datasets_user_agent(download_config.user_agent),
            )
-            commit_hash = os.path.basename(os.path.dirname(dataset_readme_path))
-        except EntryNotFoundError:
-            commit_hash = api.dataset_info(
-                path,
-                revision=revision,
-                timeout=100.0,
-            ).sha
-        except (
-            OfflineModeIsEnabled,
-            requests.exceptions.ConnectTimeout,
-            requests.exceptions.ConnectionError,
-        ) as e:
-            raise ConnectionError(f"Couldn't reach '{path}' on the Hub ({e.__class__.__name__})") from e
-        except GatedRepoError as e:
-            message = f"Dataset '{path}' is a gated dataset on the Hub."
-            if "401 Client Error" in str(e):
-                message += " You must be authenticated to access it."
-            elif "403 Client Error" in str(e):
-                message += f" Visit the dataset page at https://huggingface.co/datasets/{path} to ask for access."
-            raise DatasetNotFoundError(message) from e
-        except RevisionNotFoundError as e:
-            raise DatasetNotFoundError(f"Revision '{revision}' doesn't exist for dataset '{path}' on the Hub.") from e
-        except RepositoryNotFoundError as e:
-            raise DatasetNotFoundError(f"Dataset '{path}' doesn't exist on the Hub or cannot be accessed.") from e
-        try:
-            dataset_script_path = api.hf_hub_download(
-                repo_id=path,
-                filename=filename,
-                repo_type="dataset",
-                revision=commit_hash,
-                proxies=download_config.proxies,
-            )
-            if _require_custom_configs or (revision and revision != "main"):
-                can_load_config_from_parquet_export = False
-            elif _require_default_config_name:
-                with open(dataset_script_path, "r", encoding="utf-8") as f:
-                    can_load_config_from_parquet_export = "DEFAULT_CONFIG_NAME" not in f.read()
-            else:
-                can_load_config_from_parquet_export = True
-            if config.USE_PARQUET_EXPORT and can_load_config_from_parquet_export:
-                # If the parquet export is ready (parquet files + info available for the current sha), we can use it instead
-                # This fails when the dataset has multiple configs and a default config and
-                # the user didn't specify a configuration name (_require_default_config_name=True).
-                try:
-                    return HubDatasetModuleFactoryWithParquetExport(
-                        path, download_config=download_config, commit_hash=commit_hash
-                    ).get_module()
-                except _dataset_viewer.DatasetViewerError:
-                    pass
-            # Otherwise we must use the dataset script if the user trusts it
-            return HubDatasetModuleFactoryWithScript(
-                path,
-                commit_hash=commit_hash,
-                download_config=download_config,
-                download_mode=download_mode,
-                dynamic_modules_path=dynamic_modules_path,
-                trust_remote_code=trust_remote_code,
-            ).get_module()
-        except EntryNotFoundError:
-            # Use the infos from the parquet export except in some cases:
-            if data_dir or data_files or (revision and revision != "main"):
-                use_exported_dataset_infos = False
-            else:
-                use_exported_dataset_infos = True
-            return HubDatasetModuleFactoryWithoutScript(
-                path,
-                commit_hash=commit_hash,
-                data_dir=data_dir,
-                data_files=data_files,
-                download_config=download_config,
-                download_mode=download_mode,
-                use_exported_dataset_infos=use_exported_dataset_infos,
-            ).get_module()
+            try:
+                _raise_if_offline_mode_is_enabled()
+                dataset_readme_path = api.hf_hub_download(
+                    repo_id=path,
+                    filename=config.REPOCARD_FILENAME,
+                    repo_type="dataset",
+                    revision=revision,
+                    proxies=download_config.proxies,
+                )
+                commit_hash = os.path.basename(os.path.dirname(dataset_readme_path))
+            except EntryNotFoundError:
+                commit_hash = api.dataset_info(
+                    path,
+                    revision=revision,
+                    timeout=100.0,
+                ).sha
+            except (
+                OfflineModeIsEnabled,
+                requests.exceptions.ConnectTimeout,
+                requests.exceptions.ConnectionError,
+            ) as e:
+                raise ConnectionError(f"Couldn't reach '{path}' on the Hub ({e.__class__.__name__})") from e
+            except GatedRepoError as e:
+                message = f"Dataset '{path}' is a gated dataset on the Hub."
+                if "401 Client Error" in str(e):
+                    message += " You must be authenticated to access it."
+                elif "403 Client Error" in str(e):
+                    message += f" Visit the dataset page at https://huggingface.co/datasets/{path} to ask for access."
+                raise DatasetNotFoundError(message) from e
+            except RevisionNotFoundError as e:
+                raise DatasetNotFoundError(
+                    f"Revision '{revision}' doesn't exist for dataset '{path}' on the Hub."
+                ) from e
+            except RepositoryNotFoundError as e:
+                raise DatasetNotFoundError(f"Dataset '{path}' doesn't exist on the Hub or cannot be accessed.") from e
+            try:
+                dataset_script_path = api.hf_hub_download(
+                    repo_id=path,
+                    filename=filename,
+                    repo_type="dataset",
+                    revision=commit_hash,
+                    proxies=download_config.proxies,
+                )
+                if _require_custom_configs or (revision and revision != "main"):
+                    can_load_config_from_parquet_export = False
+                elif _require_default_config_name:
+                    with open(dataset_script_path, "r", encoding="utf-8") as f:
+                        can_load_config_from_parquet_export = "DEFAULT_CONFIG_NAME" not in f.read()
+                else:
+                    can_load_config_from_parquet_export = True
+                if config.USE_PARQUET_EXPORT and can_load_config_from_parquet_export:
+                    # If the parquet export is ready (parquet files + info available for the current sha), we can use it instead
+                    # This fails when the dataset has multiple configs and a default config and
+                    # the user didn't specify a configuration name (_require_default_config_name=True).
+                    try:
+                        out = HubDatasetModuleFactoryWithParquetExport(
+                            path, download_config=download_config, commit_hash=commit_hash
+                        ).get_module()
+                        logger.info("Loading the dataset from the Parquet export on Hugging Face.")
+                        return out
+                    except _dataset_viewer.DatasetViewerError:
+                        pass
+                # Otherwise we must use the dataset script if the user trusts it
+                return HubDatasetModuleFactoryWithScript(
+                    path,
+                    commit_hash=commit_hash,
+                    download_config=download_config,
+                    download_mode=download_mode,
+                    dynamic_modules_path=dynamic_modules_path,
+                    trust_remote_code=trust_remote_code,
+                ).get_module()
+            except EntryNotFoundError:
+                # Use the infos from the parquet export except in some cases:
+                if data_dir or data_files or (revision and revision != "main"):
+                    use_exported_dataset_infos = False
+                else:
+                    use_exported_dataset_infos = True
+                return HubDatasetModuleFactoryWithoutScript(
+                    path,
+                    commit_hash=commit_hash,
+                    data_dir=data_dir,
+                    data_files=data_files,
+                    download_config=download_config,
+                    download_mode=download_mode,
+                    use_exported_dataset_infos=use_exported_dataset_infos,
+                ).get_module()
         except Exception as e1:
             # All the attempts failed, before raising the error we should check if the module is already cached
             try:
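Two ideas from the load.py changes are easy to check in isolation. First, hf_hub_download stores files under a snapshots/<resolved_commit>/ directory in the default cache, so the parent directory of the downloaded dataset card is exactly the commit hash that the factories now pin everything to. The snippet below is only a sketch; "squad" is an arbitrary public dataset used as an example.

```python
# Why os.path.basename(os.path.dirname(...)) yields a commit hash: the default
# hf_hub_download cache layout is .../snapshots/<commit_sha>/<filename>.
import os

from huggingface_hub import hf_hub_download

readme_path = hf_hub_download(repo_id="squad", filename="README.md", repo_type="dataset")
commit_hash = os.path.basename(os.path.dirname(readme_path))
print(commit_hash)  # a full 40-character sha, usable as an immutable revision
```

Second, the Parquet-export factory now pins the refs/convert/parquet branch to its current sha via HfApi.dataset_info, so later commits on that branch cannot change which files get resolved; the next diff shows where that sha is spliced into the exported URLs.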
diff --git a/src/datasets/utils/metadata.py b/src/datasets/utils/metadata.py
index 7e639f10d18..21629407e4c 100644
--- a/src/datasets/utils/metadata.py
+++ b/src/datasets/utils/metadata.py
@@ -102,7 +102,7 @@ def _raise_if_data_files_field_not_valid(metadata_config: dict):
     @classmethod
     def _from_exported_parquet_files_and_dataset_infos(
         cls,
-        commit_hash: str,
+        parquet_commit_hash: str,
         exported_parquet_files: List[Dict[str, Any]],
         dataset_infos: DatasetInfosDict,
     ) -> "MetadataConfigs":
@@ -112,7 +112,7 @@ def _from_exported_parquet_files_and_dataset_infos(
                     {
                         "split": split_name,
                         "path": [
-                            parquet_file["url"].replace("refs%2Fconvert%2Fparquet", commit_hash)
+                            parquet_file["url"].replace("refs%2Fconvert%2Fparquet", parquet_commit_hash)
                             for parquet_file in parquet_files_for_split
                         ],
                     }
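The renamed parquet_commit_hash argument is the sha of the refs/convert/parquet branch, not the sha of the dataset's own revision, and it replaces the URL-encoded branch name inside the exported parquet URLs. A hedged sketch of that rewriting, with a made-up file list standing in for what the dataset viewer returns:

```python
# Sketch only: splice an immutable commit sha into exported parquet URLs.
parquet_commit_hash = "0123456789abcdef0123456789abcdef01234567"  # example sha
exported_parquet_files = [
    {
        "config": "default",
        "split": "train",
        "url": "https://huggingface.co/datasets/user/name/resolve/refs%2Fconvert%2Fparquet/default/train/0000.parquet",
    }
]
train_paths = [
    parquet_file["url"].replace("refs%2Fconvert%2Fparquet", parquet_commit_hash)
    for parquet_file in exported_parquet_files
    if parquet_file["split"] == "train"
]
print(train_paths[0])  # ...resolve/0123456789.../default/train/0000.parquet
```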
diff --git a/tests/test_load.py b/tests/test_load.py
index 652bb9b07ad..5d551e8afbe 100644
--- a/tests/test_load.py
+++ b/tests/test_load.py
@@ -876,19 +876,22 @@ def test_CachedDatasetModuleFactory_with_script(self):
 
 
 @pytest.mark.parametrize(
-    "factory_class",
+    "factory_class,requires_commit_hash",
     [
-        CachedDatasetModuleFactory,
-        HubDatasetModuleFactoryWithoutScript,
-        HubDatasetModuleFactoryWithScript,
-        LocalDatasetModuleFactoryWithoutScript,
-        LocalDatasetModuleFactoryWithScript,
-        PackagedDatasetModuleFactory,
+        (CachedDatasetModuleFactory, False),
+        (HubDatasetModuleFactoryWithoutScript, True),
+        (HubDatasetModuleFactoryWithScript, True),
+        (LocalDatasetModuleFactoryWithoutScript, False),
+        (LocalDatasetModuleFactoryWithScript, False),
+        (PackagedDatasetModuleFactory, False),
     ],
 )
-def test_module_factories(factory_class):
+def test_module_factories(factory_class, requires_commit_hash):
     name = "dummy_name"
-    factory = factory_class(name)
+    if requires_commit_hash:
+        factory = factory_class(name, commit_hash="foo")
+    else:
+        factory = factory_class(name)
     assert factory.name == name
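The reworked test pairs each factory class with a flag saying whether its constructor now requires a pinned commit_hash. A toy re-creation of the same parametrization pattern, with hypothetical stand-in classes instead of the real factories:

```python
# Toy illustration of parametrizing over (class, flag) pairs; the classes below are
# stand-ins, not the datasets module factories.
import pytest


class HubLikeFactory:
    def __init__(self, name, commit_hash):
        self.name = name
        self.commit_hash = commit_hash


class LocalLikeFactory:
    def __init__(self, name):
        self.name = name


@pytest.mark.parametrize(
    "factory_class,requires_commit_hash",
    [(HubLikeFactory, True), (LocalLikeFactory, False)],
)
def test_factories_toy(factory_class, requires_commit_hash):
    if requires_commit_hash:
        factory = factory_class("dummy_name", commit_hash="foo")
    else:
        factory = factory_class("dummy_name")
    assert factory.name == "dummy_name"
```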