diff --git a/src/datasets/utils/py_utils.py b/src/datasets/utils/py_utils.py index 1304a971667..97dd6fecf52 100644 --- a/src/datasets/utils/py_utils.py +++ b/src/datasets/utils/py_utils.py @@ -367,12 +367,15 @@ def _single_map_nested(args): # Singleton first to spare some computation if not isinstance(data_struct, dict) and not isinstance(data_struct, types): - return function(data_struct) + if batched: + return function([data_struct])[0] + else: + return function(data_struct) if ( batched and not isinstance(data_struct, dict) and isinstance(data_struct, types) - and all(not isinstance(v, types) for v in data_struct) + and all(not isinstance(v, (dict, types)) for v in data_struct) ): return [mapped_item for batch in iter_batched(data_struct, batch_size) for mapped_item in function(batch)] @@ -450,11 +453,11 @@ def map_nested( batched (`bool`, defaults to `False`): Provide batch of items to `function`. - + batch_size (`int`, *optional*, defaults to `1000`): Number of items per batch provided to `function` if `batched=True`. If `batch_size <= 0` or `batch_size == None`, provide the full iterable as a single batch to `function`. - + types (`tuple`, *optional*): Additional types (besides `dict` values) to apply `function` recursively to their elements. disable_tqdm (`bool`, default `True`): Whether to disable the tqdm progressbar. 
diff --git a/tests/test_download_manager.py b/tests/test_download_manager.py index 51ef61fefeb..a99349efbd2 100644 --- a/tests/test_download_manager.py +++ b/tests/test_download_manager.py @@ -8,6 +8,7 @@ from datasets.download.download_manager import DownloadManager from datasets.download.streaming_download_manager import StreamingDownloadManager from datasets.utils.file_utils import hash_url_to_filename, xopen +from datasets.utils.py_utils import NestedDataStructure URL = "http://www.mocksite.com/file1.txt" @@ -28,19 +29,15 @@ def mock_request(*args, **kwargs): return MockResponse() -@pytest.mark.parametrize("urls_type", [str, list, dict]) +@pytest.mark.parametrize("urls_type", ["str", "list", "dict", "dict_of_dict"]) def test_download_manager_download(urls_type, tmp_path, monkeypatch): import requests monkeypatch.setattr(requests, "request", mock_request) url = URL - if issubclass(urls_type, str): - urls = url - elif issubclass(urls_type, list): - urls = [url] - elif issubclass(urls_type, dict): - urls = {"train": url} + urls_types = {"str": url, "list": [url], "dict": {"train": url}, "dict_of_dict": {"train": {"en": url}}} + urls = urls_types[urls_type] dataset_name = "dummy" cache_subdir = "downloads" cache_dir_root = tmp_path @@ -50,29 +47,29 @@ def test_download_manager_download(urls_type, tmp_path, monkeypatch): ) dl_manager = DownloadManager(dataset_name=dataset_name, download_config=download_config) downloaded_paths = dl_manager.download(urls) - input_urls = urls - for downloaded_paths in [downloaded_paths]: - if isinstance(urls, str): - downloaded_paths = [downloaded_paths] - input_urls = [urls] - elif isinstance(urls, dict): - assert "train" in downloaded_paths.keys() - downloaded_paths = downloaded_paths.values() - input_urls = urls.values() - assert downloaded_paths - for downloaded_path, input_url in zip(downloaded_paths, input_urls): - assert downloaded_path == dl_manager.downloaded_paths[input_url] - downloaded_path = Path(downloaded_path) - parts 
= downloaded_path.parts - assert parts[-1] == HASH - assert parts[-2] == cache_subdir - assert downloaded_path.exists() - content = downloaded_path.read_text() - assert content == CONTENT - metadata_downloaded_path = downloaded_path.with_suffix(".json") - assert metadata_downloaded_path.exists() - metadata_content = json.loads(metadata_downloaded_path.read_text()) - assert metadata_content == {"url": URL, "etag": None} + assert isinstance(downloaded_paths, type(urls)) + if urls_type.startswith("list"): + assert len(downloaded_paths) == len(urls) + elif urls_type.startswith("dict"): + assert downloaded_paths.keys() == urls.keys() + if urls_type == "dict_of_dict": + key = list(urls.keys())[0] + assert isinstance(downloaded_paths[key], dict) + assert downloaded_paths[key].keys() == urls[key].keys() + for downloaded_path, url in zip( + NestedDataStructure(downloaded_paths).flatten(), NestedDataStructure(urls).flatten() + ): + downloaded_path = Path(downloaded_path) + parts = downloaded_path.parts + assert parts[-1] == HASH + assert parts[-2] == cache_subdir + assert downloaded_path.exists() + content = downloaded_path.read_text() + assert content == CONTENT + metadata_downloaded_path = downloaded_path.with_suffix(".json") + assert metadata_downloaded_path.exists() + metadata_content = json.loads(metadata_downloaded_path.read_text()) + assert metadata_content == {"url": URL, "etag": None} @pytest.mark.parametrize("paths_type", [str, list, dict]) diff --git a/tests/test_py_utils.py b/tests/test_py_utils.py index 4a5ab1fb05d..1b618987392 100644 --- a/tests/test_py_utils.py +++ b/tests/test_py_utils.py @@ -29,49 +29,39 @@ def add_one(i): # picklable for multiprocessing return i + 1 +def add_one_to_batch(batch): # picklable for multiprocessing + return [i + 1 for i in batch] + + @dataclass class A: x: int y: str +@pytest.mark.parametrize("batched, function", [(False, add_one), (True, add_one_to_batch)]) +@pytest.mark.parametrize("num_proc", [None, 2]) 
+@pytest.mark.parametrize( + "data_struct, expected_result", + [ + ({}, {}), + ([], []), + (1, 2), + ([1, 2], [2, 3]), + ({"a": 1, "b": 2}, {"a": 2, "b": 3}), + ({"a": [1, 2], "b": [3, 4]}, {"a": [2, 3], "b": [4, 5]}), + ({"a": {"1": 1}, "b": {"2": 2}}, {"a": {"1": 2}, "b": {"2": 3}}), + ({"a": 1, "b": [2, 3], "c": {"1": 4}}, {"a": 2, "b": [3, 4], "c": {"1": 5}}), + ({"a": 1, "b": 2, "c": 3, "d": 4}, {"a": 2, "b": 3, "c": 4, "d": 5}), + ], +) +def test_map_nested(data_struct, expected_result, num_proc, batched, function): + assert map_nested(function, data_struct, num_proc=num_proc, batched=batched) == expected_result + + class PyUtilsTest(TestCase): def test_map_nested(self): - s1 = {} - s2 = [] - s3 = 1 - s4 = [1, 2] - s5 = {"a": 1, "b": 2} - s6 = {"a": [1, 2], "b": [3, 4]} - s7 = {"a": {"1": 1}, "b": 2} - s8 = {"a": 1, "b": 2, "c": 3, "d": 4} - expected_map_nested_s1 = {} - expected_map_nested_s2 = [] - expected_map_nested_s3 = 2 - expected_map_nested_s4 = [2, 3] - expected_map_nested_s5 = {"a": 2, "b": 3} - expected_map_nested_s6 = {"a": [2, 3], "b": [4, 5]} - expected_map_nested_s7 = {"a": {"1": 2}, "b": 3} - expected_map_nested_s8 = {"a": 2, "b": 3, "c": 4, "d": 5} - self.assertEqual(map_nested(add_one, s1), expected_map_nested_s1) - self.assertEqual(map_nested(add_one, s2), expected_map_nested_s2) - self.assertEqual(map_nested(add_one, s3), expected_map_nested_s3) - self.assertEqual(map_nested(add_one, s4), expected_map_nested_s4) - self.assertEqual(map_nested(add_one, s5), expected_map_nested_s5) - self.assertEqual(map_nested(add_one, s6), expected_map_nested_s6) - self.assertEqual(map_nested(add_one, s7), expected_map_nested_s7) - self.assertEqual(map_nested(add_one, s8), expected_map_nested_s8) - num_proc = 2 - self.assertEqual(map_nested(add_one, s1, num_proc=num_proc), expected_map_nested_s1) - self.assertEqual(map_nested(add_one, s2, num_proc=num_proc), expected_map_nested_s2) - self.assertEqual(map_nested(add_one, s3, num_proc=num_proc), 
expected_map_nested_s3) - self.assertEqual(map_nested(add_one, s4, num_proc=num_proc), expected_map_nested_s4) - self.assertEqual(map_nested(add_one, s5, num_proc=num_proc), expected_map_nested_s5) - self.assertEqual(map_nested(add_one, s6, num_proc=num_proc), expected_map_nested_s6) - self.assertEqual(map_nested(add_one, s7, num_proc=num_proc), expected_map_nested_s7) - self.assertEqual(map_nested(add_one, s8, num_proc=num_proc), expected_map_nested_s8) - sn1 = {"a": np.eye(2), "b": np.zeros(3), "c": np.ones(2)} expected_map_nested_sn1_sum = {"a": 2, "b": 0, "c": 2} expected_map_nested_sn1_int = {