From e28ad9b1b564ac8ab81b93d0b6ba275a8cbcdd77 Mon Sep 17 00:00:00 2001
From: Olli Lupton
Date: Mon, 3 Jun 2024 16:07:01 +0200
Subject: [PATCH] @gspschmid suggestions from JAX-Toolbox#863

---
 .../python/jax_nsys/jax_nsys/analysis.py      | 62 +++++++++++--------
 1 file changed, 36 insertions(+), 26 deletions(-)

diff --git a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/analysis.py b/.github/container/jax_nsys/python/jax_nsys/jax_nsys/analysis.py
index 73ef2b60d..4b30af181 100644
--- a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/analysis.py
+++ b/.github/container/jax_nsys/python/jax_nsys/jax_nsys/analysis.py
@@ -10,7 +10,7 @@
 pd.options.mode.copy_on_write = True
 
 
-def element_type_in_bits(element_type: int) -> int:
+def element_type_width(element_type: int) -> int:
     """
     Given an int representing an XLA PrimitiveType enum value, return the width of
     that type in bits.
@@ -31,8 +31,32 @@ def element_type_in_bits(element_type: int) -> int:
     raise Exception(f"Could not deduce size of {enum_name}")
 
 
+def _collective_correction(kind: str, size: int) -> tuple[float, float]:
+    """
+    Calculate the correction factor from algorithm bandwidth to bus bandwidth, see:
+    https://github.com/NVIDIA/nccl-tests/blob/master/doc/PERFORMANCE.md#bus-bandwidth
+    """
+    match kind:
+        # For AllGather the size in the bandwidth calculation is the total/output size
+        # https://github.com/NVIDIA/nccl-tests/blob/c6afef0b6f76ffc55d4172d971be6cf5a08a73a4/doc/PERFORMANCE.md#allgather
+        case "all-gather":
+            return (size, (size - 1) / size)
+        case "all-reduce":
+            return (1, 2 * (size - 1) / size)
+        case "collective-broadcast":
+            return (1, 1)
+        case "collective-permute":
+            return (1, 1)
+        # For ReduceScatter the size in the bandwidth calculation is the total size
+        # https://github.com/NVIDIA/nccl-tests/blob/c6afef0b6f76ffc55d4172d971be6cf5a08a73a4/doc/PERFORMANCE.md#reducescatter
+        case "reduce-scatter":
+            return (size, (size - 1) / size)
+        case _:
+            assert False, f"Unknown collective kind {kind}"
+
+
 @functools.lru_cache
-def get_message_size(program_id: int, instruction: str) -> int:
+def get_message_size(program_id: int, instruction: str) -> pd.Series:
     """
     Given the name of a collective instruction (e.g. all-gather-start.N), calculate
     the message size in bytes. See https://openxla.org/xla/operation_semantics#allgather,
@@ -59,42 +83,28 @@ def get_message_size(program_id: int, instruction: str) -> int:
     else:
         # replica_groups is something like {{0,1},{4,5},{2,3},{6,7}}, if there are 8
         # devices that are doing pair-wise collectives
-        collective_sizes = tuple(
-            {len(group.replica_ids) for group in inst.replica_groups}
-        )
-        assert (
-            len(collective_sizes) == 1
+        collective_size = len(inst.replica_groups[0].replica_ids)
+        assert all(
+            len(group.replica_ids) == collective_size for group in inst.replica_groups
         ), f"Heterogeneous collective {inst.replica_groups} could not be interpreted"
-        collective_size = collective_sizes[0]
     total_msg_size = 0
     for operand_id in inst.operand_ids:
         _, operand = module_proto.find_instruction_by_id(operand_id)
         msg_size_bits = math.prod(
             operand.shape.dimensions,
-            start=element_type_in_bits(operand.shape.element_type),
+            start=element_type_width(operand.shape.element_type),
         )
         if inst.opcode == "reduce-scatter":
            # NCCL's convention is that the message size of a reduce-scatter is the size of output buffer:
            # https://github.com/NVIDIA/nccl/blob/ab2b89c4c339bd7f816fbc114a4b05d386b66290/src/collectives.cc#L122
-            assert msg_size_bits % collective_size == 0
-            msg_size_bits //= collective_size
-        assert msg_size_bits % 8 == 0
-        total_msg_size += msg_size_bits // 8
+            msg_size_bits, rem = divmod(msg_size_bits, collective_size)
+            assert rem == 0
+        msg_size_bytes, rem = divmod(msg_size_bits, 8)
+        assert rem == 0
+        total_msg_size += msg_size_bytes
 
-    # Calculate the correction factor from algorithm bandwidth to bus bandwidth, see:
-    # https://github.com/NVIDIA/nccl-tests/blob/master/doc/PERFORMANCE.md#bus-bandwidth
     collective = inst.opcode.removesuffix("-start")
-    bw_correction, bus_correction = {
-        # For AllGather the size in the bandwidth calculation is the total/output size
-        # https://github.com/NVIDIA/nccl-tests/blob/c6afef0b6f76ffc55d4172d971be6cf5a08a73a4/doc/PERFORMANCE.md#allgather
-        "all-gather": (collective_size, (collective_size - 1) / collective_size),
-        "all-reduce": (1, 2 * (collective_size - 1) / collective_size),
-        "collective-broadcast": (1, 1),
-        "collective-permute": (1, 1),
-        # For ReduceScatter the size in the bandwidth calculation is the total size
-        # https://github.com/NVIDIA/nccl-tests/blob/c6afef0b6f76ffc55d4172d971be6cf5a08a73a4/doc/PERFORMANCE.md#reducescatter
-        "reduce-scatter": (collective_size, (collective_size - 1) / collective_size),
-    }[collective]
+    bw_correction, bus_correction = _collective_correction(collective, collective_size)
     return pd.Series(
         [total_msg_size, collective, collective_size, bw_correction, bus_correction],
         index=[
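
Note (illustration only, not part of the patch): the diff computes the per-collective
message size and the two correction factors, but does not show how downstream jax_nsys
code consumes the returned pd.Series. Below is a minimal sketch of one plausible use,
assuming the nccl-tests convention referenced above (algorithm bandwidth is the
corrected message size divided by wall-clock time; bus bandwidth is algorithm bandwidth
scaled by the collective-specific factor). The function name collective_bandwidths and
the duration_s value are hypothetical, chosen for illustration.

    def collective_bandwidths(
        total_msg_size: int,   # bytes, as returned by get_message_size
        bw_correction: float,  # message-size correction, e.g. group size for all-gather
        bus_correction: float, # algBW -> busBW factor, e.g. 2*(n-1)/n for all-reduce
        duration_s: float,     # measured duration of the collective in seconds
    ) -> tuple[float, float]:
        # Algorithm bandwidth: corrected message size over wall-clock time (bytes/s).
        alg_bw = bw_correction * total_msg_size / duration_s
        # Bus bandwidth: algorithm bandwidth scaled per the NCCL bus-bandwidth doc.
        return alg_bw, alg_bw * bus_correction

    # Example: a 2-device all-reduce of 1 MiB that took 100 microseconds.
    alg_bw, bus_bw = collective_bandwidths(1024 * 1024, 1, 2 * (2 - 1) / 2, 100e-6)
    # alg_bw ~= 10.5 GB/s; bus_bw is the same here, since 2*(n-1)/n == 1 for n == 2.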