Skip to content

Commit

Permalink
Fix download for dict of dicts of URLs (#6871)
Browse files Browse the repository at this point in the history
* Test DownloadManager.download with dict of dicts

* Test map_nested when batched

* Fix _single_map_nested when batched

* Fix versionadded to 2.19.0 in map_nested docstring
  • Loading branch information
albertvillanova committed May 6, 2024
1 parent 0d3c746 commit a5a76a4
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 68 deletions.
11 changes: 7 additions & 4 deletions src/datasets/utils/py_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,12 +367,15 @@ def _single_map_nested(args):

# Singleton first to spare some computation
if not isinstance(data_struct, dict) and not isinstance(data_struct, types):
return function(data_struct)
if batched:
return function([data_struct])[0]
else:
return function(data_struct)
if (
batched
and not isinstance(data_struct, dict)
and isinstance(data_struct, types)
and all(not isinstance(v, types) for v in data_struct)
and all(not isinstance(v, (dict, types)) for v in data_struct)
):
return [mapped_item for batch in iter_batched(data_struct, batch_size) for mapped_item in function(batch)]

Expand Down Expand Up @@ -450,11 +453,11 @@ def map_nested(
<Added version="2.5.0"/>
batched (`bool`, defaults to `False`):
Provide batch of items to `function`.
<Added version="2.18.1"/>
<Added version="2.19.0"/>
batch_size (`int`, *optional*, defaults to `1000`):
Number of items per batch provided to `function` if `batched=True`.
If `batch_size <= 0` or `batch_size == None`, provide the full iterable as a single batch to `function`.
<Added version="2.18.1"/>
<Added version="2.19.0"/>
types (`tuple`, *optional*): Additional types (besides `dict` values) to apply `function` recursively to their
elements.
disable_tqdm (`bool`, default `True`): Whether to disable the tqdm progressbar.
Expand Down
57 changes: 27 additions & 30 deletions tests/test_download_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from datasets.download.download_manager import DownloadManager
from datasets.download.streaming_download_manager import StreamingDownloadManager
from datasets.utils.file_utils import hash_url_to_filename, xopen
from datasets.utils.py_utils import NestedDataStructure


URL = "http://www.mocksite.com/file1.txt"
Expand All @@ -28,19 +29,15 @@ def mock_request(*args, **kwargs):
return MockResponse()


@pytest.mark.parametrize("urls_type", [str, list, dict])
@pytest.mark.parametrize("urls_type", ["str", "list", "dict", "dict_of_dict"])
def test_download_manager_download(urls_type, tmp_path, monkeypatch):
import requests

monkeypatch.setattr(requests, "request", mock_request)

url = URL
if issubclass(urls_type, str):
urls = url
elif issubclass(urls_type, list):
urls = [url]
elif issubclass(urls_type, dict):
urls = {"train": url}
urls_types = {"str": url, "list": [url], "dict": {"train": url}, "dict_of_dict": {"train": {"en": url}}}
urls = urls_types[urls_type]
dataset_name = "dummy"
cache_subdir = "downloads"
cache_dir_root = tmp_path
Expand All @@ -50,29 +47,29 @@ def test_download_manager_download(urls_type, tmp_path, monkeypatch):
)
dl_manager = DownloadManager(dataset_name=dataset_name, download_config=download_config)
downloaded_paths = dl_manager.download(urls)
input_urls = urls
for downloaded_paths in [downloaded_paths]:
if isinstance(urls, str):
downloaded_paths = [downloaded_paths]
input_urls = [urls]
elif isinstance(urls, dict):
assert "train" in downloaded_paths.keys()
downloaded_paths = downloaded_paths.values()
input_urls = urls.values()
assert downloaded_paths
for downloaded_path, input_url in zip(downloaded_paths, input_urls):
assert downloaded_path == dl_manager.downloaded_paths[input_url]
downloaded_path = Path(downloaded_path)
parts = downloaded_path.parts
assert parts[-1] == HASH
assert parts[-2] == cache_subdir
assert downloaded_path.exists()
content = downloaded_path.read_text()
assert content == CONTENT
metadata_downloaded_path = downloaded_path.with_suffix(".json")
assert metadata_downloaded_path.exists()
metadata_content = json.loads(metadata_downloaded_path.read_text())
assert metadata_content == {"url": URL, "etag": None}
# The returned structure must mirror the container type of the input URLs.
assert isinstance(downloaded_paths, type(urls))
# Branch on the parametrized value `urls_type`, not on the string literal
# "urls_type": the literal never starts with "list" or "dict" and never equals
# "dict_of_dict", so with the literal none of the checks below would ever run.
if urls_type.startswith("list"):
    assert len(downloaded_paths) == len(urls)
elif urls_type.startswith("dict"):
    assert downloaded_paths.keys() == urls.keys()
    if urls_type == "dict_of_dict":
        # The inner mapping must be preserved as a dict with the same keys.
        key = list(urls.keys())[0]
        assert isinstance(downloaded_paths[key], dict)
        assert downloaded_paths[key].keys() == urls[key].keys()
for downloaded_path, url in zip(
NestedDataStructure(downloaded_paths).flatten(), NestedDataStructure(urls).flatten()
):
downloaded_path = Path(downloaded_path)
parts = downloaded_path.parts
assert parts[-1] == HASH
assert parts[-2] == cache_subdir
assert downloaded_path.exists()
content = downloaded_path.read_text()
assert content == CONTENT
metadata_downloaded_path = downloaded_path.with_suffix(".json")
assert metadata_downloaded_path.exists()
metadata_content = json.loads(metadata_downloaded_path.read_text())
assert metadata_content == {"url": URL, "etag": None}


@pytest.mark.parametrize("paths_type", [str, list, dict])
Expand Down
58 changes: 24 additions & 34 deletions tests/test_py_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,49 +29,39 @@ def add_one(i): # picklable for multiprocessing
return i + 1


def add_one_to_batch(batch):  # picklable for multiprocessing
    """Return a new list holding each item of ``batch`` incremented by one."""
    return [i + 1 for i in batch]


@dataclass
class A:
    """Small dataclass with an integer field ``x`` and a string field ``y``."""

    x: int
    y: str


@pytest.mark.parametrize("batched, function", [(False, add_one), (True, add_one_to_batch)])
@pytest.mark.parametrize("num_proc", [None, 2])
@pytest.mark.parametrize(
    "data_struct, expected_result",
    [
        ({}, {}),
        ([], []),
        (1, 2),
        ([1, 2], [2, 3]),
        ({"a": 1, "b": 2}, {"a": 2, "b": 3}),
        ({"a": [1, 2], "b": [3, 4]}, {"a": [2, 3], "b": [4, 5]}),
        ({"a": {"1": 1}, "b": {"2": 2}}, {"a": {"1": 2}, "b": {"2": 3}}),
        ({"a": 1, "b": [2, 3], "c": {"1": 4}}, {"a": 2, "b": [3, 4], "c": {"1": 5}}),
        ({"a": 1, "b": 2, "c": 3, "d": 4}, {"a": 2, "b": 3, "c": 4, "d": 5}),
    ],
)
def test_map_nested(data_struct, expected_result, num_proc, batched, function):
    """Check that ``map_nested`` adds one to every leaf of ``data_struct``.

    Each input structure is run with and without ``num_proc`` and in both
    modes: per-item (``add_one``) and batched (``add_one_to_batch``).
    """
    assert map_nested(function, data_struct, num_proc=num_proc, batched=batched) == expected_result


class PyUtilsTest(TestCase):
def test_map_nested(self):
s1 = {}
s2 = []
s3 = 1
s4 = [1, 2]
s5 = {"a": 1, "b": 2}
s6 = {"a": [1, 2], "b": [3, 4]}
s7 = {"a": {"1": 1}, "b": 2}
s8 = {"a": 1, "b": 2, "c": 3, "d": 4}
expected_map_nested_s1 = {}
expected_map_nested_s2 = []
expected_map_nested_s3 = 2
expected_map_nested_s4 = [2, 3]
expected_map_nested_s5 = {"a": 2, "b": 3}
expected_map_nested_s6 = {"a": [2, 3], "b": [4, 5]}
expected_map_nested_s7 = {"a": {"1": 2}, "b": 3}
expected_map_nested_s8 = {"a": 2, "b": 3, "c": 4, "d": 5}
self.assertEqual(map_nested(add_one, s1), expected_map_nested_s1)
self.assertEqual(map_nested(add_one, s2), expected_map_nested_s2)
self.assertEqual(map_nested(add_one, s3), expected_map_nested_s3)
self.assertEqual(map_nested(add_one, s4), expected_map_nested_s4)
self.assertEqual(map_nested(add_one, s5), expected_map_nested_s5)
self.assertEqual(map_nested(add_one, s6), expected_map_nested_s6)
self.assertEqual(map_nested(add_one, s7), expected_map_nested_s7)
self.assertEqual(map_nested(add_one, s8), expected_map_nested_s8)

num_proc = 2
self.assertEqual(map_nested(add_one, s1, num_proc=num_proc), expected_map_nested_s1)
self.assertEqual(map_nested(add_one, s2, num_proc=num_proc), expected_map_nested_s2)
self.assertEqual(map_nested(add_one, s3, num_proc=num_proc), expected_map_nested_s3)
self.assertEqual(map_nested(add_one, s4, num_proc=num_proc), expected_map_nested_s4)
self.assertEqual(map_nested(add_one, s5, num_proc=num_proc), expected_map_nested_s5)
self.assertEqual(map_nested(add_one, s6, num_proc=num_proc), expected_map_nested_s6)
self.assertEqual(map_nested(add_one, s7, num_proc=num_proc), expected_map_nested_s7)
self.assertEqual(map_nested(add_one, s8, num_proc=num_proc), expected_map_nested_s8)

sn1 = {"a": np.eye(2), "b": np.zeros(3), "c": np.ones(2)}
expected_map_nested_sn1_sum = {"a": 2, "b": 0, "c": 2}
expected_map_nested_sn1_int = {
Expand Down

0 comments on commit a5a76a4

Please sign in to comment.