Fix prepare_single_hop_path_and_storage_options (#7068)
* Transform all HF HTTP URLs to HF protocol

* Fix test URL

* Remove HF headers for non-HF HTTP URLs

* Fix for HTTP storage_options without 'headers'

* Remove unused cookies

* Refactor

* Refactor list to set to check membership

* Refactor to add protocol key to storage_options only at the end

* Fix overwriting storage_options nested values

* Add tests

* Revert "Transform all HF HTTP URLs to HF protocol"

This reverts commit a337212.

* Test that DownloadConfig.storage_options are not modified

* Fix so DownloadConfig.storage_options are not modified

* Refactor fix

* Test also GitHub URL

* Fix DownloadConfig.storage_options for GitHub URL
albertvillanova authored Jul 29, 2024
1 parent 30000fb commit baea190
Showing 2 changed files with 86 additions and 23 deletions.
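
At its core, the change replaces "alias and overwrite" with "copy and merge": the per-protocol options taken from DownloadConfig.storage_options are copied, nested keys that have defaults (client_kwargs, headers) are popped and rebuilt with the defaults listed first so user-supplied values win without being lost, and the protocol namespace is applied only at the end. The standalone sketch below illustrates that pattern under those assumptions; the names DEFAULT_CLIENT_KWARGS, prepare_storage_options and user_options are illustrative and do not appear in the datasets codebase.

# Illustrative sketch of the copy-and-merge pattern used by the fix.
# The names here are made up; the real logic lives in
# _prepare_single_hop_path_and_storage_options (see the diff below).

DEFAULT_CLIENT_KWARGS = {"trust_env": True}  # same default the diff sets for http/https


def prepare_storage_options(user_options: dict, protocol: str = "https") -> dict:
    options = user_options.copy()  # never mutate the caller's dict
    # Pop the nested key and merge with the defaults first, so user values win
    # without discarding defaults such as trust_env.
    client_kwargs = options.pop("client_kwargs", {})
    options["client_kwargs"] = {**DEFAULT_CLIENT_KWARGS, **client_kwargs}
    # Namespace under the protocol only at the very end.
    return {protocol: options}


user_options = {"client_kwargs": {"raise_for_status": True}}
print(prepare_storage_options(user_options))
# {'https': {'client_kwargs': {'trust_env': True, 'raise_for_status': True}}}
print(user_options)
# unchanged: {'client_kwargs': {'raise_for_status': True}}
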
36 changes: 15 additions & 21 deletions src/datasets/utils/file_utils.py
@@ -1150,7 +1150,7 @@ def _prepare_single_hop_path_and_storage_options(
         urlpath = "hf://" + urlpath[len(config.HF_ENDPOINT) + 1 :].replace("/resolve/", "@", 1)
     protocol = urlpath.split("://")[0] if "://" in urlpath else "file"
     if download_config is not None and protocol in download_config.storage_options:
-        storage_options = download_config.storage_options[protocol]
+        storage_options = download_config.storage_options[protocol].copy()
     elif download_config is not None and protocol not in download_config.storage_options:
         storage_options = {
             option_name: option_value
@@ -1159,40 +1159,34 @@ def _prepare_single_hop_path_and_storage_options(
         }
     else:
         storage_options = {}
-    if storage_options:
-        storage_options = {protocol: storage_options}
-    if protocol in ["http", "https"]:
-        storage_options[protocol] = {
-            "headers": {
-                **get_authentication_headers_for_url(urlpath, token=token),
-                "user-agent": get_datasets_user_agent(),
-            },
-            "client_kwargs": {"trust_env": True},  # Enable reading proxy env variables.
-            **(storage_options.get(protocol, {})),
-        }
+    if protocol in {"http", "https"}:
+        client_kwargs = storage_options.pop("client_kwargs", {})
+        storage_options["client_kwargs"] = {"trust_env": True, **client_kwargs}  # Enable reading proxy env variables
         if "drive.google.com" in urlpath:
             response = http_head(urlpath)
             cookies = None
             for k, v in response.cookies.items():
                 if k.startswith("download_warning"):
                     urlpath += "&confirm=" + v
                     cookies = response.cookies
-            storage_options[protocol] = {"cookies": cookies, **storage_options.get(protocol, {})}
-        # Fix Google Drive URL to avoid Virus scan warning
-        if "drive.google.com" in urlpath and "confirm=" not in urlpath:
-            urlpath += "&confirm=t"
+            storage_options = {"cookies": cookies, **storage_options}
+            # Fix Google Drive URL to avoid Virus scan warning
+            if "confirm=" not in urlpath:
+                urlpath += "&confirm=t"
         if urlpath.startswith("https://raw.githubusercontent.com/"):
             # Workaround for served data with gzip content-encoding: https://github.com/fsspec/filesystem_spec/issues/389
-            storage_options[protocol]["headers"]["Accept-Encoding"] = "identity"
+            headers = storage_options.pop("headers", {})
+            storage_options["headers"] = {"Accept-Encoding": "identity", **headers}
     elif protocol == "hf":
-        storage_options[protocol] = {
+        storage_options = {
             "token": token,
             "endpoint": config.HF_ENDPOINT,
-            **storage_options.get(protocol, {}),
+            **storage_options,
         }
         # streaming with block_size=0 is only implemented in 0.21 (see https://github.com/huggingface/huggingface_hub/pull/1967)
         if config.HF_HUB_VERSION < version.parse("0.21.0"):
-            storage_options[protocol]["block_size"] = "default"
+            storage_options["block_size"] = "default"
+    if storage_options:
+        storage_options = {protocol: storage_options}
     return urlpath, storage_options


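The options returned by the function are namespaced under the protocol, e.g. {"https": {...}}. One plausible way to consume the prepared pair is sketched below; the actual call sites inside datasets are not part of this diff, so the fsspec.filesystem usage is an assumption for illustration, and the URL and storage options mirror one of the new test cases.

import fsspec
from datasets import DownloadConfig
from datasets.utils.file_utils import _prepare_single_hop_path_and_storage_options

# Mirrors the GitHub raw URL test case added below.
url = "https://raw.githubusercontent.com/data.txt"
download_config = DownloadConfig(storage_options={"https": {"headers": {"x-test": "true"}}})

urlpath, storage_options = _prepare_single_hop_path_and_storage_options(url, download_config)
protocol = urlpath.split("://")[0]
# storage_options == {"https": {"client_kwargs": {"trust_env": True},
#                               "headers": {"x-test": "true", "Accept-Encoding": "identity"}}}
fs = fsspec.filesystem(protocol, **storage_options[protocol])  # assumed downstream use, not shown in this diff
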
73 changes: 71 additions & 2 deletions tests/test_file_utils.py
@@ -12,6 +12,7 @@
 from datasets.utils.file_utils import (
     OfflineModeIsEnabled,
     _get_extraction_protocol,
+    _prepare_single_hop_path_and_storage_options,
     cached_path,
     fsspec_get,
     fsspec_head,
@@ -47,7 +48,7 @@
 
 FILE_PATH = "file"
 
-TEST_URL = "https://huggingface.co/datasets/hf-internal-testing/dataset_with_script/raw/main/some_text.txt"
+TEST_URL = "https://huggingface.co/datasets/hf-internal-testing/dataset_with_script/resolve/main/some_text.txt"
 TEST_URL_CONTENT = "foo\nbar\nfoobar"
 
 TEST_GG_DRIVE_FILENAME = "train.tsv"
@@ -90,7 +91,6 @@ def test_cached_path_protocols(protocol, monkeypatch, tmp_path):
     urls = {"hf": "hf://datasets/org-name/ds-name@main/filename.ext", "s3": "s3://bucket-name/filename.ext"}
     url = urls[protocol]
     _ = cached_path(url, download_config=download_config)
-    assert True
     for mock in [mock_fsspec_head, mock_fsspec_get]:
         assert mock.called
         assert mock.call_count == 1
@@ -197,6 +197,75 @@ def test_fsspec_offline(tmp_path_factory):
         fsspec_head("s3://huggingface.co")
 
 
+@pytest.mark.parametrize(
+    "urlpath, download_config, expected_urlpath, expected_storage_options",
+    [
+        (
+            "https://huggingface.co/datasets/hf-internal-testing/dataset_with_script/resolve/main/some_text.txt",
+            DownloadConfig(),
+            "hf://datasets/hf-internal-testing/dataset_with_script@main/some_text.txt",
+            {"hf": {"endpoint": "https://huggingface.co", "token": None}},
+        ),
+        (
+            "https://huggingface.co/datasets/hf-internal-testing/dataset_with_script/resolve/main/some_text.txt",
+            DownloadConfig(token="MY-TOKEN"),
+            "hf://datasets/hf-internal-testing/dataset_with_script@main/some_text.txt",
+            {"hf": {"endpoint": "https://huggingface.co", "token": "MY-TOKEN"}},
+        ),
+        (
+            "https://huggingface.co/datasets/hf-internal-testing/dataset_with_script/resolve/main/some_text.txt",
+            DownloadConfig(token="MY-TOKEN", storage_options={"hf": {"on_error": "omit"}}),
+            "hf://datasets/hf-internal-testing/dataset_with_script@main/some_text.txt",
+            {"hf": {"endpoint": "https://huggingface.co", "token": "MY-TOKEN", "on_error": "omit"}},
+        ),
+        (
+            "https://domain.org/data.txt",
+            DownloadConfig(),
+            "https://domain.org/data.txt",
+            {"https": {"client_kwargs": {"trust_env": True}}},
+        ),
+        (
+            "https://domain.org/data.txt",
+            DownloadConfig(storage_options={"https": {"block_size": "omit"}}),
+            "https://domain.org/data.txt",
+            {"https": {"client_kwargs": {"trust_env": True}, "block_size": "omit"}},
+        ),
+        (
+            "https://domain.org/data.txt",
+            DownloadConfig(storage_options={"https": {"client_kwargs": {"raise_for_status": True}}}),
+            "https://domain.org/data.txt",
+            {"https": {"client_kwargs": {"trust_env": True, "raise_for_status": True}}},
+        ),
+        (
+            "https://domain.org/data.txt",
+            DownloadConfig(storage_options={"https": {"client_kwargs": {"trust_env": False}}}),
+            "https://domain.org/data.txt",
+            {"https": {"client_kwargs": {"trust_env": False}}},
+        ),
+        (
+            "https://raw.githubusercontent.com/data.txt",
+            DownloadConfig(storage_options={"https": {"headers": {"x-test": "true"}}}),
+            "https://raw.githubusercontent.com/data.txt",
+            {
+                "https": {
+                    "client_kwargs": {"trust_env": True},
+                    "headers": {"x-test": "true", "Accept-Encoding": "identity"},
+                }
+            },
+        ),
+    ],
+)
+def test_prepare_single_hop_path_and_storage_options(
+    urlpath, download_config, expected_urlpath, expected_storage_options
+):
+    original_download_config_storage_options = str(download_config.storage_options)
+    prepared_urlpath, storage_options = _prepare_single_hop_path_and_storage_options(urlpath, download_config)
+    assert prepared_urlpath == expected_urlpath
+    assert storage_options == expected_storage_options
+    # Check that DownloadConfig.storage_options are not modified:
+    assert str(download_config.storage_options) == original_download_config_storage_options
+
+
 class DummyTestFS(AbstractFileSystem):
     protocol = "mock"
     _file_class = AbstractBufferedFile
