From a3251f7fd232b39f61e2ae62939982ce3bef90b6 Mon Sep 17 00:00:00 2001
From: Quentin Lhoest
Date: Thu, 17 Oct 2024 15:52:09 +0200
Subject: [PATCH] actually do a single call in dataset_module_factory

---
 src/datasets/load.py                  | 232 +++++++++++++-------------
 src/datasets/utils/_dataset_viewer.py |  14 +-
 src/datasets/utils/metadata.py        |   4 +-
 3 files changed, 130 insertions(+), 120 deletions(-)

diff --git a/src/datasets/load.py b/src/datasets/load.py
index 6178d91638f..6d073886b3d 100644
--- a/src/datasets/load.py
+++ b/src/datasets/load.py
@@ -36,8 +36,14 @@
 import requests
 import yaml
 from fsspec.core import url_to_fs
-from huggingface_hub import DatasetCard, DatasetCardData, HfApi, HfFileSystem
-from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError, RevisionNotFoundError, get_session
+from huggingface_hub import DatasetCard, DatasetCardData, HfApi
+from huggingface_hub.utils import (
+    EntryNotFoundError,
+    GatedRepoError,
+    RepositoryNotFoundError,
+    RevisionNotFoundError,
+    get_session,
+)
 
 from . import __version__, config
 from .arrow_dataset import Dataset
@@ -82,7 +88,7 @@
     relative_to_absolute_path,
     url_or_path_join,
 )
-from .utils.hub import check_auth, hf_dataset_url
+from .utils.hub import hf_dataset_url
 from .utils.info_utils import VerificationMode, is_small_dataset
 from .utils.logging import get_logger
 from .utils.metadata import MetadataConfigs
@@ -974,18 +980,20 @@ class HubDatasetModuleFactoryWithoutScript(_DatasetModuleFactory):
     def __init__(
         self,
         name: str,
-        revision: Optional[Union[str, Version]] = None,
+        commit_hash: Optional[str] = None,
         data_dir: Optional[str] = None,
         data_files: Optional[Union[str, List, Dict]] = None,
         download_config: Optional[DownloadConfig] = None,
         download_mode: Optional[Union[DownloadMode, str]] = None,
+        use_exported_dataset_infos: bool = False,
     ):
         self.name = name
-        self.revision = revision
+        self.commit_hash = commit_hash
         self.data_files = data_files
         self.data_dir = data_dir
         self.download_config = download_config or DownloadConfig()
         self.download_mode = download_mode
+        self.use_exported_dataset_infos = use_exported_dataset_infos
         increase_load_count(name)
 
     def get_module(self) -> DatasetModule:
@@ -1002,24 +1010,18 @@ def get_module(self) -> DatasetModule:
                 repo_id=self.name,
                 filename=config.REPOCARD_FILENAME,
                 repo_type="dataset",
-                revision=self.revision,
+                revision=self.commit_hash,
                 proxies=self.download_config.proxies,
             )
-            commit_hash = os.path.dirname(dataset_readme_path)
             dataset_card_data = DatasetCard.load(dataset_readme_path).data
         except FileNotFoundError:
-            commit_hash = api.dataset_info(
-                self.name,
-                revision=self.revision,
-                timeout=100.0,
-            ).sha
            dataset_card_data = DatasetCardData()
         download_config = self.download_config.copy()
         if download_config.download_desc is None:
             download_config.download_desc = "Downloading standalone yaml"
         try:
             standalone_yaml_path = cached_path(
-                hf_dataset_url(self.name, config.REPOYAML_FILENAME, revision=commit_hash),
+                hf_dataset_url(self.name, config.REPOYAML_FILENAME, revision=self.commit_hash),
                 download_config=download_config,
             )
             with open(standalone_yaml_path, "r", encoding="utf-8") as f:
@@ -1030,18 +1032,13 @@ def get_module(self) -> DatasetModule:
                     dataset_card_data = DatasetCardData(**_dataset_card_data_dict)
         except FileNotFoundError:
             pass
-        base_path = f"hf://datasets/{self.name}@{commit_hash}/{self.data_dir or ''}".rstrip("/")
+        base_path = f"hf://datasets/{self.name}@{self.commit_hash}/{self.data_dir or ''}".rstrip("/")
         metadata_configs = MetadataConfigs.from_dataset_card_data(dataset_card_data)
         dataset_infos = DatasetInfosDict.from_dataset_card_data(dataset_card_data)
-        # Use the infos from the parquet export except in some cases:
-        if self.data_dir or self.data_files or (self.revision and self.revision != "main"):
-            use_exported_dataset_infos = False
-        else:
-            use_exported_dataset_infos = True
-        if config.USE_PARQUET_EXPORT and use_exported_dataset_infos:
+        if config.USE_PARQUET_EXPORT and self.use_exported_dataset_infos:
             try:
                 exported_dataset_infos = _dataset_viewer.get_exported_dataset_infos(
-                    dataset=self.name, revision=self.revision, token=self.download_config.token
+                    dataset=self.name, commit_hash=self.commit_hash, token=self.download_config.token
                 )
                 exported_dataset_infos = DatasetInfosDict(
                     {
@@ -1114,7 +1111,7 @@ def get_module(self) -> DatasetModule:
             ]
             default_config_name = None
         builder_kwargs = {
-            "base_path": hf_dataset_url(self.name, "", revision=commit_hash).rstrip("/"),
+            "base_path": hf_dataset_url(self.name, "", revision=self.commit_hash).rstrip("/"),
             "repo_id": self.name,
             "dataset_name": camelcase_to_snakecase(Path(self.name).name),
         }
@@ -1126,7 +1123,7 @@ def get_module(self) -> DatasetModule:
         try:
             # this file is deprecated and was created automatically in old versions of push_to_hub
             dataset_infos_path = cached_path(
-                hf_dataset_url(self.name, config.DATASETDICT_INFOS_FILENAME, revision=commit_hash),
+                hf_dataset_url(self.name, config.DATASETDICT_INFOS_FILENAME, revision=self.commit_hash),
                 download_config=download_config,
             )
             with open(dataset_infos_path, encoding="utf-8") as f:
@@ -1149,7 +1146,7 @@ def get_module(self) -> DatasetModule:
 
         return DatasetModule(
             module_path,
-            commit_hash,
+            self.commit_hash,
             builder_kwargs,
             dataset_infos=dataset_infos,
             builder_configs_parameters=BuilderConfigsParameters(
@@ -1168,20 +1165,20 @@ class HubDatasetModuleFactoryWithParquetExport(_DatasetModuleFactory):
     def __init__(
         self,
         name: str,
-        revision: Optional[str] = None,
+        commit_hash: Optional[str] = None,
         download_config: Optional[DownloadConfig] = None,
     ):
         self.name = name
-        self.revision = revision
+        self.commit_hash = commit_hash
         self.download_config = download_config or DownloadConfig()
         increase_load_count(name)
 
     def get_module(self) -> DatasetModule:
         exported_parquet_files = _dataset_viewer.get_exported_parquet_files(
-            dataset=self.name, revision=self.revision, token=self.download_config.token
+            dataset=self.name, commit_hash=self.commit_hash, token=self.download_config.token
         )
         exported_dataset_infos = _dataset_viewer.get_exported_dataset_infos(
-            dataset=self.name, revision=self.revision, token=self.download_config.token
+            dataset=self.name, commit_hash=self.commit_hash, token=self.download_config.token
         )
         dataset_infos = DatasetInfosDict(
             {
@@ -1189,15 +1186,8 @@ def get_module(self) -> DatasetModule:
                 for config_name in exported_dataset_infos
             }
         )
-        hfh_dataset_info = HfApi(config.HF_ENDPOINT).dataset_info(
-            self.name,
-            revision="refs/convert/parquet",
-            token=self.download_config.token,
-            timeout=100.0,
-        )
-        revision = hfh_dataset_info.sha  # fix the revision in case there are new commits in the meantime
         metadata_configs = MetadataConfigs._from_exported_parquet_files_and_dataset_infos(
-            revision=revision, exported_parquet_files=exported_parquet_files, dataset_infos=dataset_infos
+            commit_hash=self.commit_hash, exported_parquet_files=exported_parquet_files, dataset_infos=dataset_infos
         )
         module_path, _ = _PACKAGED_DATASETS_MODULES["parquet"]
         builder_configs, default_config_name = create_builder_configs_from_metadata_configs(
@@ -1206,7 +1196,6 @@ def get_module(self) -> DatasetModule:
             supports_metadata=False,
             download_config=self.download_config,
         )
-        hash = self.revision
         builder_kwargs = {
             "repo_id": self.name,
             "dataset_name": camelcase_to_snakecase(Path(self.name).name),
         }
 
         return DatasetModule(
             module_path,
-            hash,
+            self.commit_hash,
             builder_kwargs,
             dataset_infos=dataset_infos,
             builder_configs_parameters=BuilderConfigsParameters(
@@ -1234,14 +1223,14 @@ class HubDatasetModuleFactoryWithScript(_DatasetModuleFactory):
     def __init__(
         self,
         name: str,
-        revision: Optional[Union[str, Version]] = None,
+        commit_hash: Optional[str] = None,
         download_config: Optional[DownloadConfig] = None,
         download_mode: Optional[Union[DownloadMode, str]] = None,
         dynamic_modules_path: Optional[str] = None,
         trust_remote_code: Optional[bool] = None,
     ):
         self.name = name
-        self.revision = revision
+        self.commit_hash = commit_hash
         self.download_config = download_config or DownloadConfig()
         self.download_mode = download_mode
         self.dynamic_modules_path = dynamic_modules_path
@@ -1249,14 +1238,14 @@ def __init__(
         increase_load_count(name)
 
     def download_loading_script(self) -> str:
-        file_path = hf_dataset_url(self.name, self.name.split("/")[-1] + ".py", revision=self.revision)
+        file_path = hf_dataset_url(self.name, self.name.split("/")[-1] + ".py", revision=self.commit_hash)
         download_config = self.download_config.copy()
         if download_config.download_desc is None:
             download_config.download_desc = "Downloading builder script"
         return cached_path(file_path, download_config=download_config)
 
     def download_dataset_infos_file(self) -> str:
-        dataset_infos = hf_dataset_url(self.name, config.DATASETDICT_INFOS_FILENAME, revision=self.revision)
+        dataset_infos = hf_dataset_url(self.name, config.DATASETDICT_INFOS_FILENAME, revision=self.commit_hash)
         # Download the dataset infos file if available
         download_config = self.download_config.copy()
         if download_config.download_desc is None:
@@ -1270,7 +1259,7 @@ def download_dataset_infos_file(self) -> str:
             return None
 
     def download_dataset_readme_file(self) -> str:
-        readme_url = hf_dataset_url(self.name, config.REPOCARD_FILENAME, revision=self.revision)
+        readme_url = hf_dataset_url(self.name, config.REPOCARD_FILENAME, revision=self.commit_hash)
         # Download the dataset infos file if available
         download_config = self.download_config.copy()
         if download_config.download_desc is None:
@@ -1299,7 +1288,7 @@ def get_module(self) -> DatasetModule:
         imports = get_imports(local_path)
         local_imports, library_imports = _download_additional_modules(
             name=self.name,
-            base_path=hf_dataset_url(self.name, "", revision=self.revision),
+            base_path=hf_dataset_url(self.name, "", revision=self.commit_hash),
             imports=imports,
             download_config=self.download_config,
         )
@@ -1576,78 +1565,97 @@ def dataset_module_factory(
         ).get_module()
     # Try remotely
    elif is_relative_path(path) and path.count("/") <= 1:
+        # Get the Dataset Card + get the revision + check authentication all in one call
+        # We fix the commit_hash in case there are new commits in the meantime
+        api = HfApi(
+            endpoint=config.HF_ENDPOINT,
+            token=download_config.token,
+            library_name="datasets",
+            library_version=__version__,
+            user_agent=get_datasets_user_agent(download_config.user_agent),
+        )
         try:
             _raise_if_offline_mode_is_enabled()
-            hf_api = HfApi(config.HF_ENDPOINT)
-            try:
-                dataset_info = hf_api.dataset_info(
-                    repo_id=path,
-                    revision=revision,
-                    token=download_config.token,
-                    timeout=100.0,
-                )
-            except (
-                OfflineModeIsEnabled,
-                requests.exceptions.ConnectTimeout,
-                requests.exceptions.ConnectionError,
-            ) as e:
-                raise ConnectionError(f"Couldn't reach '{path}' on the Hub ({e.__class__.__name__})") from e
-            except RevisionNotFoundError as e:
-                raise DatasetNotFoundError(
-                    f"Revision '{revision}' doesn't exist for dataset '{path}' on the Hub."
-                ) from e
-            except RepositoryNotFoundError as e:
-                raise DatasetNotFoundError(f"Dataset '{path}' doesn't exist on the Hub or cannot be accessed.") from e
-            if dataset_info.gated:
+            dataset_readme_path = api.hf_hub_download(
+                repo_id=path,
+                filename=config.REPOCARD_FILENAME,
+                repo_type="dataset",
+                revision=revision,
+                proxies=download_config.proxies,
+            )
+            commit_hash = os.path.basename(os.path.dirname(dataset_readme_path))
+        except EntryNotFoundError:
+            commit_hash = api.dataset_info(
+                path,
+                revision=revision,
+                timeout=100.0,
+            ).sha
+        except (
+            OfflineModeIsEnabled,
+            requests.exceptions.ConnectTimeout,
+            requests.exceptions.ConnectionError,
+        ) as e:
+            raise ConnectionError(f"Couldn't reach '{path}' on the Hub ({e.__class__.__name__})") from e
+        except GatedRepoError as e:
+            message = f"Dataset '{path}' is a gated dataset on the Hub."
+            if "401 Client Error" in str(e):
+                message += " You must be authenticated to access it."
+            elif "403 Client Error" in str(e):
+                message += f" Visit the dataset page at https://huggingface.co/datasets/{path} to ask for access."
+            raise DatasetNotFoundError(message) from e
+        except RevisionNotFoundError as e:
+            raise DatasetNotFoundError(f"Revision '{revision}' doesn't exist for dataset '{path}' on the Hub.") from e
+        except RepositoryNotFoundError as e:
+            raise DatasetNotFoundError(f"Dataset '{path}' doesn't exist on the Hub or cannot be accessed.") from e
+        try:
+            dataset_script_path = api.hf_hub_download(
+                repo_id=path,
+                filename=filename,
+                repo_type="dataset",
+                revision=commit_hash,
+                proxies=download_config.proxies,
+            )
+            if _require_custom_configs or (revision and revision != "main"):
+                can_load_config_from_parquet_export = False
+            elif _require_default_config_name:
+                with open(dataset_script_path, "r", encoding="utf-8") as f:
+                    can_load_config_from_parquet_export = "DEFAULT_CONFIG_NAME" not in f.read()
+            else:
+                can_load_config_from_parquet_export = True
+            if config.USE_PARQUET_EXPORT and can_load_config_from_parquet_export:
+                # If the parquet export is ready (parquet files + info available for the current sha), we can use it instead
+                # This fails when the dataset has multiple configs and a default config and
+                # the user didn't specify a configuration name (_require_default_config_name=True).
                 try:
-                    check_auth(hf_api, repo_id=path, token=download_config.token)
-                except GatedRepoError as e:
-                    message = f"Dataset '{path}' is a gated dataset on the Hub."
-                    if "401 Client Error" in str(e):
-                        message += " You must be authenticated to access it."
-                    elif "403 Client Error" in str(e):
-                        message += (
-                            f" Visit the dataset page at https://huggingface.co/datasets/{path} to ask for access."
-                        )
-                    raise DatasetNotFoundError(message) from e
-
-            if filename in [sibling.rfilename for sibling in dataset_info.siblings]:  # contains a dataset script
-                fs = HfFileSystem(endpoint=config.HF_ENDPOINT, token=download_config.token)
-                if _require_custom_configs or (revision and revision != "main"):
-                    can_load_config_from_parquet_export = False
-                elif _require_default_config_name:
-                    with fs.open(f"datasets/{path}/{filename}", "r", encoding="utf-8") as f:
-                        can_load_config_from_parquet_export = "DEFAULT_CONFIG_NAME" not in f.read()
-                else:
-                    can_load_config_from_parquet_export = True
-                if config.USE_PARQUET_EXPORT and can_load_config_from_parquet_export:
-                    # If the parquet export is ready (parquet files + info available for the current sha), we can use it instead
-                    # This fails when the dataset has multiple configs and a default config and
-                    # the user didn't specify a configuration name (_require_default_config_name=True).
-                    try:
-                        return HubDatasetModuleFactoryWithParquetExport(
-                            path, download_config=download_config, revision=dataset_info.sha
-                        ).get_module()
-                    except _dataset_viewer.DatasetViewerError:
-                        pass
-                # Otherwise we must use the dataset script if the user trusts it
-                return HubDatasetModuleFactoryWithScript(
-                    path,
-                    revision=revision,
-                    download_config=download_config,
-                    download_mode=download_mode,
-                    dynamic_modules_path=dynamic_modules_path,
-                    trust_remote_code=trust_remote_code,
-                ).get_module()
+                    return HubDatasetModuleFactoryWithParquetExport(
+                        path, download_config=download_config, commit_hash=commit_hash
+                    ).get_module()
+                except _dataset_viewer.DatasetViewerError:
+                    pass
+            # Otherwise we must use the dataset script if the user trusts it
+            return HubDatasetModuleFactoryWithScript(
+                path,
+                commit_hash=commit_hash,
+                download_config=download_config,
+                download_mode=download_mode,
+                dynamic_modules_path=dynamic_modules_path,
+                trust_remote_code=trust_remote_code,
+            ).get_module()
+        except EntryNotFoundError:
+            # Use the infos from the parquet export except in some cases:
+            if data_dir or data_files or (revision and revision != "main"):
+                use_exported_dataset_infos = False
             else:
-                return HubDatasetModuleFactoryWithoutScript(
-                    path,
-                    revision=revision,
-                    data_dir=data_dir,
-                    data_files=data_files,
-                    download_config=download_config,
-                    download_mode=download_mode,
-                ).get_module()
+                use_exported_dataset_infos = True
+            return HubDatasetModuleFactoryWithoutScript(
+                path,
+                commit_hash=commit_hash,
+                data_dir=data_dir,
+                data_files=data_files,
+                download_config=download_config,
+                download_mode=download_mode,
+                use_exported_dataset_infos=use_exported_dataset_infos,
+            ).get_module()
         except Exception as e1:
             # All the attempts failed, before raising the error we should check if the module is already cached
             try:
diff --git a/src/datasets/utils/_dataset_viewer.py b/src/datasets/utils/_dataset_viewer.py
index b8cf6ea49e1..092741a956c 100644
--- a/src/datasets/utils/_dataset_viewer.py
+++ b/src/datasets/utils/_dataset_viewer.py
@@ -23,7 +23,9 @@ class DatasetViewerError(DatasetsError):
     """
 
 
-def get_exported_parquet_files(dataset: str, revision: str, token: Optional[Union[str, bool]]) -> List[Dict[str, Any]]:
+def get_exported_parquet_files(
+    dataset: str, commit_hash: str, token: Optional[Union[str, bool]]
+) -> List[Dict[str, Any]]:
     """
     Get the dataset exported parquet files
     Docs: https://huggingface.co/docs/datasets-server/parquet
@@ -37,7 +39,7 @@ def get_exported_parquet_files(dataset: str, revision: str, token: Optional[Unio
         )
         parquet_data_files_response.raise_for_status()
if "X-Revision" in parquet_data_files_response.headers: - if parquet_data_files_response.headers["X-Revision"] == revision or revision is None: + if parquet_data_files_response.headers["X-Revision"] == commit_hash or commit_hash is None: parquet_data_files_response_json = parquet_data_files_response.json() if ( parquet_data_files_response_json.get("partial") is False @@ -50,7 +52,7 @@ def get_exported_parquet_files(dataset: str, revision: str, token: Optional[Unio logger.debug(f"Parquet export for {dataset} is not completely ready yet.") else: logger.debug( - f"Parquet export for {dataset} is available but outdated (revision='{parquet_data_files_response.headers['X-Revision']}')" + f"Parquet export for {dataset} is available but outdated (commit_hash='{parquet_data_files_response.headers['X-Revision']}')" ) except Exception as e: # noqa catch any exception of the dataset viewer API and consider the parquet export doesn't exist logger.debug(f"No parquet export for {dataset} available ({type(e).__name__}: {e})") @@ -58,7 +60,7 @@ def get_exported_parquet_files(dataset: str, revision: str, token: Optional[Unio def get_exported_dataset_infos( - dataset: str, revision: str, token: Optional[Union[str, bool]] + dataset: str, commit_hash: str, token: Optional[Union[str, bool]] ) -> Dict[str, Dict[str, Any]]: """ Get the dataset information, can be useful to get e.g. the dataset features. @@ -73,7 +75,7 @@ def get_exported_dataset_infos( ) info_response.raise_for_status() if "X-Revision" in info_response.headers: - if info_response.headers["X-Revision"] == revision or revision is None: + if info_response.headers["X-Revision"] == commit_hash or commit_hash is None: info_response = info_response.json() if ( info_response.get("partial") is False @@ -86,7 +88,7 @@ def get_exported_dataset_infos( logger.debug(f"Dataset info for {dataset} is not completely ready yet.") else: logger.debug( - f"Dataset info for {dataset} is available but outdated (revision='{info_response.headers['X-Revision']}')" + f"Dataset info for {dataset} is available but outdated (commit_hash='{info_response.headers['X-Revision']}')" ) except Exception as e: # noqa catch any exception of the dataset viewer API and consider the dataset info doesn't exist logger.debug(f"No dataset info for {dataset} available ({type(e).__name__}: {e})") diff --git a/src/datasets/utils/metadata.py b/src/datasets/utils/metadata.py index fa463272213..7e639f10d18 100644 --- a/src/datasets/utils/metadata.py +++ b/src/datasets/utils/metadata.py @@ -102,7 +102,7 @@ def _raise_if_data_files_field_not_valid(metadata_config: dict): @classmethod def _from_exported_parquet_files_and_dataset_infos( cls, - revision: str, + commit_hash: str, exported_parquet_files: List[Dict[str, Any]], dataset_infos: DatasetInfosDict, ) -> "MetadataConfigs": @@ -112,7 +112,7 @@ def _from_exported_parquet_files_and_dataset_infos( { "split": split_name, "path": [ - parquet_file["url"].replace("refs%2Fconvert%2Fparquet", revision) + parquet_file["url"].replace("refs%2Fconvert%2Fparquet", commit_hash) for parquet_file in parquet_files_for_split ], }