Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix prepare_single_hop_path_and_storage_options #7068

Merged
merged 16 commits into from
Jul 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 15 additions & 21 deletions src/datasets/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1150,7 +1150,7 @@ def _prepare_single_hop_path_and_storage_options(
urlpath = "hf://" + urlpath[len(config.HF_ENDPOINT) + 1 :].replace("/resolve/", "@", 1)
protocol = urlpath.split("://")[0] if "://" in urlpath else "file"
if download_config is not None and protocol in download_config.storage_options:
storage_options = download_config.storage_options[protocol]
storage_options = download_config.storage_options[protocol].copy()
elif download_config is not None and protocol not in download_config.storage_options:
storage_options = {
option_name: option_value
Expand All @@ -1159,40 +1159,34 @@ def _prepare_single_hop_path_and_storage_options(
}
else:
storage_options = {}
if storage_options:
storage_options = {protocol: storage_options}
if protocol in ["http", "https"]:
storage_options[protocol] = {
"headers": {
**get_authentication_headers_for_url(urlpath, token=token),
"user-agent": get_datasets_user_agent(),
},
"client_kwargs": {"trust_env": True}, # Enable reading proxy env variables.
**(storage_options.get(protocol, {})),
}
if protocol in {"http", "https"}:
client_kwargs = storage_options.pop("client_kwargs", {})
storage_options["client_kwargs"] = {"trust_env": True, **client_kwargs} # Enable reading proxy env variables
if "drive.google.com" in urlpath:
response = http_head(urlpath)
cookies = None
for k, v in response.cookies.items():
if k.startswith("download_warning"):
urlpath += "&confirm=" + v
cookies = response.cookies
storage_options[protocol] = {"cookies": cookies, **storage_options.get(protocol, {})}
# Fix Google Drive URL to avoid Virus scan warning
if "drive.google.com" in urlpath and "confirm=" not in urlpath:
urlpath += "&confirm=t"
storage_options = {"cookies": cookies, **storage_options}
# Fix Google Drive URL to avoid Virus scan warning
if "confirm=" not in urlpath:
urlpath += "&confirm=t"
if urlpath.startswith("https://raw.githubusercontent.com/"):
# Workaround for served data with gzip content-encoding: https://github.com/fsspec/filesystem_spec/issues/389
storage_options[protocol]["headers"]["Accept-Encoding"] = "identity"
headers = storage_options.pop("headers", {})
storage_options["headers"] = {"Accept-Encoding": "identity", **headers}
elif protocol == "hf":
storage_options[protocol] = {
storage_options = {
"token": token,
"endpoint": config.HF_ENDPOINT,
**storage_options.get(protocol, {}),
**storage_options,
}
# streaming with block_size=0 is only implemented in 0.21 (see https://github.com/huggingface/huggingface_hub/pull/1967)
if config.HF_HUB_VERSION < version.parse("0.21.0"):
storage_options[protocol]["block_size"] = "default"
storage_options["block_size"] = "default"
if storage_options:
storage_options = {protocol: storage_options}
return urlpath, storage_options


Expand Down
73 changes: 71 additions & 2 deletions tests/test_file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from datasets.utils.file_utils import (
OfflineModeIsEnabled,
_get_extraction_protocol,
_prepare_single_hop_path_and_storage_options,
cached_path,
fsspec_get,
fsspec_head,
Expand Down Expand Up @@ -47,7 +48,7 @@

FILE_PATH = "file"

TEST_URL = "https://huggingface.co/datasets/hf-internal-testing/dataset_with_script/raw/main/some_text.txt"
TEST_URL = "https://huggingface.co/datasets/hf-internal-testing/dataset_with_script/resolve/main/some_text.txt"
TEST_URL_CONTENT = "foo\nbar\nfoobar"

TEST_GG_DRIVE_FILENAME = "train.tsv"
Expand Down Expand Up @@ -90,7 +91,6 @@ def test_cached_path_protocols(protocol, monkeypatch, tmp_path):
urls = {"hf": "hf://datasets/org-name/ds-name@main/filename.ext", "s3": "s3://bucket-name/filename.ext"}
url = urls[protocol]
_ = cached_path(url, download_config=download_config)
assert True
for mock in [mock_fsspec_head, mock_fsspec_get]:
assert mock.called
assert mock.call_count == 1
Expand Down Expand Up @@ -197,6 +197,75 @@ def test_fsspec_offline(tmp_path_factory):
fsspec_head("s3://huggingface.co")


@pytest.mark.parametrize(
"urlpath, download_config, expected_urlpath, expected_storage_options",
[
(
"https://huggingface.co/datasets/hf-internal-testing/dataset_with_script/resolve/main/some_text.txt",
DownloadConfig(),
"hf://datasets/hf-internal-testing/dataset_with_script@main/some_text.txt",
{"hf": {"endpoint": "https://huggingface.co", "token": None}},
),
(
"https://huggingface.co/datasets/hf-internal-testing/dataset_with_script/resolve/main/some_text.txt",
DownloadConfig(token="MY-TOKEN"),
"hf://datasets/hf-internal-testing/dataset_with_script@main/some_text.txt",
{"hf": {"endpoint": "https://huggingface.co", "token": "MY-TOKEN"}},
),
(
"https://huggingface.co/datasets/hf-internal-testing/dataset_with_script/resolve/main/some_text.txt",
DownloadConfig(token="MY-TOKEN", storage_options={"hf": {"on_error": "omit"}}),
"hf://datasets/hf-internal-testing/dataset_with_script@main/some_text.txt",
{"hf": {"endpoint": "https://huggingface.co", "token": "MY-TOKEN", "on_error": "omit"}},
),
(
"https://domain.org/data.txt",
DownloadConfig(),
"https://domain.org/data.txt",
{"https": {"client_kwargs": {"trust_env": True}}},
),
(
"https://domain.org/data.txt",
DownloadConfig(storage_options={"https": {"block_size": "omit"}}),
"https://domain.org/data.txt",
{"https": {"client_kwargs": {"trust_env": True}, "block_size": "omit"}},
),
(
"https://domain.org/data.txt",
DownloadConfig(storage_options={"https": {"client_kwargs": {"raise_for_status": True}}}),
"https://domain.org/data.txt",
{"https": {"client_kwargs": {"trust_env": True, "raise_for_status": True}}},
),
(
"https://domain.org/data.txt",
DownloadConfig(storage_options={"https": {"client_kwargs": {"trust_env": False}}}),
"https://domain.org/data.txt",
{"https": {"client_kwargs": {"trust_env": False}}},
),
(
"https://raw.githubusercontent.com/data.txt",
DownloadConfig(storage_options={"https": {"headers": {"x-test": "true"}}}),
"https://raw.githubusercontent.com/data.txt",
{
"https": {
"client_kwargs": {"trust_env": True},
"headers": {"x-test": "true", "Accept-Encoding": "identity"},
}
},
),
],
)
def test_prepare_single_hop_path_and_storage_options(
urlpath, download_config, expected_urlpath, expected_storage_options
):
original_download_config_storage_options = str(download_config.storage_options)
prepared_urlpath, storage_options = _prepare_single_hop_path_and_storage_options(urlpath, download_config)
assert prepared_urlpath == expected_urlpath
assert storage_options == expected_storage_options
# Check that DownloadConfig.storage_options are not modified:
assert str(download_config.storage_options) == original_download_config_storage_options


class DummyTestFS(AbstractFileSystem):
protocol = "mock"
_file_class = AbstractBufferedFile
Expand Down
Loading