diff --git a/src/datasets/utils/py_utils.py b/src/datasets/utils/py_utils.py
index 1304a971667..97dd6fecf52 100644
--- a/src/datasets/utils/py_utils.py
+++ b/src/datasets/utils/py_utils.py
@@ -367,12 +367,15 @@ def _single_map_nested(args):
# Singleton first to spare some computation
if not isinstance(data_struct, dict) and not isinstance(data_struct, types):
- return function(data_struct)
+ if batched:
+ return function([data_struct])[0]
+ else:
+ return function(data_struct)
if (
batched
and not isinstance(data_struct, dict)
and isinstance(data_struct, types)
- and all(not isinstance(v, types) for v in data_struct)
+ and all(not isinstance(v, (dict, types)) for v in data_struct)
):
return [mapped_item for batch in iter_batched(data_struct, batch_size) for mapped_item in function(batch)]
@@ -450,11 +453,11 @@ def map_nested(
batched (`bool`, defaults to `False`):
Provide batch of items to `function`.
-
+
batch_size (`int`, *optional*, defaults to `1000`):
Number of items per batch provided to `function` if `batched=True`.
If `batch_size <= 0` or `batch_size == None`, provide the full iterable as a single batch to `function`.
-
+
types (`tuple`, *optional*): Additional types (besides `dict` values) to apply `function` recursively to their
elements.
disable_tqdm (`bool`, default `True`): Whether to disable the tqdm progressbar.
diff --git a/tests/test_download_manager.py b/tests/test_download_manager.py
index 51ef61fefeb..a99349efbd2 100644
--- a/tests/test_download_manager.py
+++ b/tests/test_download_manager.py
@@ -8,6 +8,7 @@
from datasets.download.download_manager import DownloadManager
from datasets.download.streaming_download_manager import StreamingDownloadManager
from datasets.utils.file_utils import hash_url_to_filename, xopen
+from datasets.utils.py_utils import NestedDataStructure
URL = "http://www.mocksite.com/file1.txt"
@@ -28,19 +29,15 @@ def mock_request(*args, **kwargs):
return MockResponse()
-@pytest.mark.parametrize("urls_type", [str, list, dict])
+@pytest.mark.parametrize("urls_type", ["str", "list", "dict", "dict_of_dict"])
def test_download_manager_download(urls_type, tmp_path, monkeypatch):
import requests
monkeypatch.setattr(requests, "request", mock_request)
url = URL
- if issubclass(urls_type, str):
- urls = url
- elif issubclass(urls_type, list):
- urls = [url]
- elif issubclass(urls_type, dict):
- urls = {"train": url}
+ urls_types = {"str": url, "list": [url], "dict": {"train": url}, "dict_of_dict": {"train": {"en": url}}}
+ urls = urls_types[urls_type]
dataset_name = "dummy"
cache_subdir = "downloads"
cache_dir_root = tmp_path
@@ -50,29 +47,29 @@ def test_download_manager_download(urls_type, tmp_path, monkeypatch):
)
dl_manager = DownloadManager(dataset_name=dataset_name, download_config=download_config)
downloaded_paths = dl_manager.download(urls)
- input_urls = urls
- for downloaded_paths in [downloaded_paths]:
- if isinstance(urls, str):
- downloaded_paths = [downloaded_paths]
- input_urls = [urls]
- elif isinstance(urls, dict):
- assert "train" in downloaded_paths.keys()
- downloaded_paths = downloaded_paths.values()
- input_urls = urls.values()
- assert downloaded_paths
- for downloaded_path, input_url in zip(downloaded_paths, input_urls):
- assert downloaded_path == dl_manager.downloaded_paths[input_url]
- downloaded_path = Path(downloaded_path)
- parts = downloaded_path.parts
- assert parts[-1] == HASH
- assert parts[-2] == cache_subdir
- assert downloaded_path.exists()
- content = downloaded_path.read_text()
- assert content == CONTENT
- metadata_downloaded_path = downloaded_path.with_suffix(".json")
- assert metadata_downloaded_path.exists()
- metadata_content = json.loads(metadata_downloaded_path.read_text())
- assert metadata_content == {"url": URL, "etag": None}
+ assert isinstance(downloaded_paths, type(urls))
+    if urls_type.startswith("list"):
+ assert len(downloaded_paths) == len(urls)
+    elif urls_type.startswith("dict"):
+ assert downloaded_paths.keys() == urls.keys()
+        if urls_type == "dict_of_dict":
+ key = list(urls.keys())[0]
+ assert isinstance(downloaded_paths[key], dict)
+ assert downloaded_paths[key].keys() == urls[key].keys()
+ for downloaded_path, url in zip(
+ NestedDataStructure(downloaded_paths).flatten(), NestedDataStructure(urls).flatten()
+ ):
+ downloaded_path = Path(downloaded_path)
+ parts = downloaded_path.parts
+ assert parts[-1] == HASH
+ assert parts[-2] == cache_subdir
+ assert downloaded_path.exists()
+ content = downloaded_path.read_text()
+ assert content == CONTENT
+ metadata_downloaded_path = downloaded_path.with_suffix(".json")
+ assert metadata_downloaded_path.exists()
+ metadata_content = json.loads(metadata_downloaded_path.read_text())
+ assert metadata_content == {"url": URL, "etag": None}
@pytest.mark.parametrize("paths_type", [str, list, dict])
diff --git a/tests/test_py_utils.py b/tests/test_py_utils.py
index 4a5ab1fb05d..1b618987392 100644
--- a/tests/test_py_utils.py
+++ b/tests/test_py_utils.py
@@ -29,49 +29,39 @@ def add_one(i): # picklable for multiprocessing
return i + 1
+def add_one_to_batch(batch): # picklable for multiprocessing
+ return [i + 1 for i in batch]
+
+
@dataclass
class A:
x: int
y: str
+@pytest.mark.parametrize("batched, function", [(False, add_one), (True, add_one_to_batch)])
+@pytest.mark.parametrize("num_proc", [None, 2])
+@pytest.mark.parametrize(
+ "data_struct, expected_result",
+ [
+ ({}, {}),
+ ([], []),
+ (1, 2),
+ ([1, 2], [2, 3]),
+ ({"a": 1, "b": 2}, {"a": 2, "b": 3}),
+ ({"a": [1, 2], "b": [3, 4]}, {"a": [2, 3], "b": [4, 5]}),
+ ({"a": {"1": 1}, "b": {"2": 2}}, {"a": {"1": 2}, "b": {"2": 3}}),
+ ({"a": 1, "b": [2, 3], "c": {"1": 4}}, {"a": 2, "b": [3, 4], "c": {"1": 5}}),
+ ({"a": 1, "b": 2, "c": 3, "d": 4}, {"a": 2, "b": 3, "c": 4, "d": 5}),
+ ],
+)
+def test_map_nested(data_struct, expected_result, num_proc, batched, function):
+ assert map_nested(function, data_struct, num_proc=num_proc, batched=batched) == expected_result
+
+
class PyUtilsTest(TestCase):
def test_map_nested(self):
- s1 = {}
- s2 = []
- s3 = 1
- s4 = [1, 2]
- s5 = {"a": 1, "b": 2}
- s6 = {"a": [1, 2], "b": [3, 4]}
- s7 = {"a": {"1": 1}, "b": 2}
- s8 = {"a": 1, "b": 2, "c": 3, "d": 4}
- expected_map_nested_s1 = {}
- expected_map_nested_s2 = []
- expected_map_nested_s3 = 2
- expected_map_nested_s4 = [2, 3]
- expected_map_nested_s5 = {"a": 2, "b": 3}
- expected_map_nested_s6 = {"a": [2, 3], "b": [4, 5]}
- expected_map_nested_s7 = {"a": {"1": 2}, "b": 3}
- expected_map_nested_s8 = {"a": 2, "b": 3, "c": 4, "d": 5}
- self.assertEqual(map_nested(add_one, s1), expected_map_nested_s1)
- self.assertEqual(map_nested(add_one, s2), expected_map_nested_s2)
- self.assertEqual(map_nested(add_one, s3), expected_map_nested_s3)
- self.assertEqual(map_nested(add_one, s4), expected_map_nested_s4)
- self.assertEqual(map_nested(add_one, s5), expected_map_nested_s5)
- self.assertEqual(map_nested(add_one, s6), expected_map_nested_s6)
- self.assertEqual(map_nested(add_one, s7), expected_map_nested_s7)
- self.assertEqual(map_nested(add_one, s8), expected_map_nested_s8)
-
num_proc = 2
- self.assertEqual(map_nested(add_one, s1, num_proc=num_proc), expected_map_nested_s1)
- self.assertEqual(map_nested(add_one, s2, num_proc=num_proc), expected_map_nested_s2)
- self.assertEqual(map_nested(add_one, s3, num_proc=num_proc), expected_map_nested_s3)
- self.assertEqual(map_nested(add_one, s4, num_proc=num_proc), expected_map_nested_s4)
- self.assertEqual(map_nested(add_one, s5, num_proc=num_proc), expected_map_nested_s5)
- self.assertEqual(map_nested(add_one, s6, num_proc=num_proc), expected_map_nested_s6)
- self.assertEqual(map_nested(add_one, s7, num_proc=num_proc), expected_map_nested_s7)
- self.assertEqual(map_nested(add_one, s8, num_proc=num_proc), expected_map_nested_s8)
-
sn1 = {"a": np.eye(2), "b": np.zeros(3), "c": np.ones(2)}
expected_map_nested_sn1_sum = {"a": 2, "b": 0, "c": 2}
expected_map_nested_sn1_int = {