diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 7ba052d3fde..25b27d091a3 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -5677,7 +5677,8 @@ def push_to_hub( ) repo_id = repo_url.repo_id - if revision is not None: + if revision is not None and not revision.startswith("refs/pr/"): + # We do not call create_branch for a PR reference: 400 Bad Request api.create_branch(repo_id, branch=revision, token=token, repo_type="dataset", exist_ok=True) if not data_dir: diff --git a/src/datasets/dataset_dict.py b/src/datasets/dataset_dict.py index 5d2d9dcd9ff..cf4a6cc98f8 100644 --- a/src/datasets/dataset_dict.py +++ b/src/datasets/dataset_dict.py @@ -1708,7 +1708,8 @@ def push_to_hub( ) repo_id = repo_url.repo_id - if revision is not None: + if revision is not None and not revision.startswith("refs/pr/"): + # We do not call create_branch for a PR reference: 400 Bad Request api.create_branch(repo_id, branch=revision, token=token, repo_type="dataset", exist_ok=True) if not data_dir: diff --git a/tests/test_hub.py b/tests/test_hub.py index ab766d01779..9485fe83a71 100644 --- a/tests/test_hub.py +++ b/tests/test_hub.py @@ -66,9 +66,8 @@ def test_convert_to_parquet(temporary_repo, hf_api, hf_token, ci_hub_config, ci_ _ = convert_to_parquet(repo_id, token=hf_token, trust_remote_code=True) # mock_create_branch assert mock_create_branch.called - assert mock_create_branch.call_count == 2 - for call_args, expected_branch in zip(mock_create_branch.call_args_list, ["refs/pr/1", "script"]): - assert call_args.kwargs.get("branch") == expected_branch + assert mock_create_branch.call_count == 1 + assert mock_create_branch.call_args.kwargs.get("branch") == "script" # mock_create_commit assert mock_create_commit.called assert mock_create_commit.call_count == 2