From dfab493a7bb334b27a53cfe5bd0be760733de848 Mon Sep 17 00:00:00 2001 From: Olli Lupton Date: Wed, 21 Aug 2024 09:34:41 +0000 Subject: [PATCH] cleanup --- .github/container/jax-nccl-test | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/.github/container/jax-nccl-test b/.github/container/jax-nccl-test index 6a5543fc9..706713baf 100755 --- a/.github/container/jax-nccl-test +++ b/.github/container/jax-nccl-test @@ -178,21 +178,9 @@ if __name__ == "__main__": if input.size < max_elements: # TODO: make this sensitive to whether the permutation does or # does not cross NVLink domain boundaries - if False: - node_size = 8 - permutation = [ - ( - node_size * rank + local_device, - node_size * rank + (local_device + 1) % node_size, - ) - for rank in range(n_devices // 8) - for local_device in range(8) - ] - else: - permutation = [ - (i, (i + 1) % n_devices) for i in range(n_devices) - ] - + permutation = [ + (i, (i + 1) % n_devices) for i in range(n_devices) + ] result = jax.lax.ppermute(input, "i", permutation) assert result.shape == (1, values_per_device), result.shape else: