Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
olupton committed Aug 21, 2024
1 parent 4e181f9 commit dfab493
Showing 1 changed file with 3 additions and 15 deletions.
18 changes: 3 additions & 15 deletions .github/container/jax-nccl-test
Original file line number Diff line number Diff line change
Expand Up @@ -178,21 +178,9 @@ if __name__ == "__main__":
if input.size < max_elements:
# TODO: make this sensitive to whether the permutation does or
# does not cross NVLink domain boundaries
if False:
node_size = 8
permutation = [
(
node_size * rank + local_device,
node_size * rank + (local_device + 1) % node_size,
)
for rank in range(n_devices // 8)
for local_device in range(8)
]
else:
permutation = [
(i, (i + 1) % n_devices) for i in range(n_devices)
]

permutation = [
(i, (i + 1) % n_devices) for i in range(n_devices)
]
result = jax.lax.ppermute(input, "i", permutation)
assert result.shape == (1, values_per_device), result.shape
else:
Expand Down

0 comments on commit dfab493

Please sign in to comment.