Skip to content

Commit

Permalink
Fix push_to_hub by not calling create_branch if branch exists (#7069)
Browse files Browse the repository at this point in the history
* Fix push_to_hub by not calling create_branch if branch exists

* Fix push_to_hub by not calling create_branch if branch exists

* Reword comment

* Fix push_to_hub by not calling create_branch if PR ref

* Update test
  • Loading branch information
albertvillanova committed Aug 14, 2024
1 parent d690c4f commit 1c9870d
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 5 deletions.
3 changes: 2 additions & 1 deletion src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5677,7 +5677,8 @@ def push_to_hub(
)
repo_id = repo_url.repo_id

if revision is not None:
if revision is not None and not revision.startswith("refs/pr/"):
# We do not call create_branch for a PR reference: 400 Bad Request
api.create_branch(repo_id, branch=revision, token=token, repo_type="dataset", exist_ok=True)

if not data_dir:
Expand Down
3 changes: 2 additions & 1 deletion src/datasets/dataset_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -1708,7 +1708,8 @@ def push_to_hub(
)
repo_id = repo_url.repo_id

if revision is not None:
if revision is not None and not revision.startswith("refs/pr/"):
# We do not call create_branch for a PR reference: 400 Bad Request
api.create_branch(repo_id, branch=revision, token=token, repo_type="dataset", exist_ok=True)

if not data_dir:
Expand Down
5 changes: 2 additions & 3 deletions tests/test_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,8 @@ def test_convert_to_parquet(temporary_repo, hf_api, hf_token, ci_hub_config, ci_
_ = convert_to_parquet(repo_id, token=hf_token, trust_remote_code=True)
# mock_create_branch
assert mock_create_branch.called
assert mock_create_branch.call_count == 2
for call_args, expected_branch in zip(mock_create_branch.call_args_list, ["refs/pr/1", "script"]):
assert call_args.kwargs.get("branch") == expected_branch
assert mock_create_branch.call_count == 1
assert mock_create_branch.call_args.kwargs.get("branch") == "script"
# mock_create_commit
assert mock_create_commit.called
assert mock_create_commit.call_count == 2
Expand Down

0 comments on commit 1c9870d

Please sign in to comment.