From 798b01cd245e83410893b8d9b3b896d9a4371cb0 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Thu, 20 Jun 2024 12:33:42 +0200 Subject: [PATCH 1/2] Move dataset card creation to method for easier overriding --- src/datasets/arrow_dataset.py | 22 +++++++++++++++++++--- src/datasets/dataset_dict.py | 21 ++++++++++++++++++--- 2 files changed, 37 insertions(+), 6 deletions(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 779091a75af..022139b416e 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -5781,9 +5781,9 @@ def push_to_hub( CommitOperationAdd(path_in_repo=config.DATASETDICT_INFOS_FILENAME, path_or_fileobj=buffer) ) # push to README - DatasetInfosDict({config_name: info_to_dump}).to_dataset_card_data(dataset_card_data) - MetadataConfigs({config_name: metadata_config_to_dump}).to_dataset_card_data(dataset_card_data) - dataset_card = DatasetCard(f"---\n{dataset_card_data}\n---\n") if dataset_card is None else dataset_card + dataset_card = self._create_dataset_card( + dataset_card_data, dataset_card, config_name, info_to_dump, metadata_config_to_dump + ) additions.append( CommitOperationAdd(path_in_repo=config.REPOCARD_FILENAME, path_or_fileobj=str(dataset_card).encode()) ) @@ -5826,6 +5826,22 @@ def push_to_hub( ) return commit_info + def _create_dataset_card( + self, + dataset_card_data: DatasetCardData, + dataset_card: Optional[DatasetCard], + config_name: str, + info_to_dump: DatasetInfo, + metadata_config_to_dump: MetadataConfigs, + ) -> DatasetCard: + if dataset_card: + return dataset_card + + DatasetInfosDict({config_name: info_to_dump}).to_dataset_card_data(dataset_card_data) + MetadataConfigs({config_name: metadata_config_to_dump}).to_dataset_card_data(dataset_card_data) + return DatasetCard(f"---\n{dataset_card_data}\n---\n") + + @transmit_format @fingerprint_transform(inplace=False) def add_column(self, name: str, column: Union[list, np.array], new_fingerprint: str): diff --git a/src/datasets/dataset_dict.py b/src/datasets/dataset_dict.py index 9e3e8543b77..c13913eb57f 100644 --- a/src/datasets/dataset_dict.py +++ b/src/datasets/dataset_dict.py @@ -1820,9 +1820,9 @@ def push_to_hub( CommitOperationAdd(path_in_repo=config.DATASETDICT_INFOS_FILENAME, path_or_fileobj=buffer) ) # push to README - DatasetInfosDict({config_name: info_to_dump}).to_dataset_card_data(dataset_card_data) - MetadataConfigs({config_name: metadata_config_to_dump}).to_dataset_card_data(dataset_card_data) - dataset_card = DatasetCard(f"---\n{dataset_card_data}\n---\n") if dataset_card is None else dataset_card + dataset_card = self._create_dataset_card( + dataset_card_data, dataset_card, config_name, info_to_dump, metadata_config_to_dump + ) additions.append( CommitOperationAdd(path_in_repo=config.REPOCARD_FILENAME, path_or_fileobj=str(dataset_card).encode()) ) @@ -1865,6 +1865,21 @@ def push_to_hub( ) return commit_info + def _create_dataset_card( + self, + dataset_card_data: DatasetCardData, + dataset_card: Optional[DatasetCard], + config_name: str, + info_to_dump: DatasetInfo, + metadata_config_to_dump: MetadataConfigs, + ) -> DatasetCard: + if dataset_card: + return dataset_card + + DatasetInfosDict({config_name: info_to_dump}).to_dataset_card_data(dataset_card_data) + MetadataConfigs({config_name: metadata_config_to_dump}).to_dataset_card_data(dataset_card_data) + return DatasetCard(f"---\n{dataset_card_data}\n---\n") + class IterableDatasetDict(dict): def __repr__(self): From 7a9ccbcfb898f87e02267d8ead8a69a8505b4777 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Thu, 20 Jun 2024 12:49:10 +0200 Subject: [PATCH 2/2] Reformat --- src/datasets/arrow_dataset.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 022139b416e..ddea54cabe3 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -5841,7 +5841,6 @@ def _create_dataset_card( MetadataConfigs({config_name: metadata_config_to_dump}).to_dataset_card_data(dataset_card_data) return DatasetCard(f"---\n{dataset_card_data}\n---\n") - @transmit_format @fingerprint_transform(inplace=False) def add_column(self, name: str, column: Union[list, np.array], new_fingerprint: str):