Skip to content

Commit

Permalink
@gspschmid suggestions from JAX-Toolbox#863
Browse files Browse the repository at this point in the history
  • Loading branch information
olupton committed Jun 4, 2024
1 parent 775e10d commit e28ad9b
Showing 1 changed file with 36 additions and 26 deletions.
62 changes: 36 additions & 26 deletions .github/container/jax_nsys/python/jax_nsys/jax_nsys/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
pd.options.mode.copy_on_write = True


def element_type_in_bits(element_type: int) -> int:
def element_type_width(element_type: int) -> int:
"""
Given an int representing an XLA PrimitiveType enum value, return the width of that
type in bits.
Expand All @@ -31,8 +31,32 @@ def element_type_in_bits(element_type: int) -> int:
raise Exception(f"Could not deduce size of {enum_name}")


def _collective_correction(kind: str, size: int) -> tuple[float, float]:
"""
Calculate the correction factor from algorithm bandwidth to bus bandwidth, see:
https://github.com/NVIDIA/nccl-tests/blob/master/doc/PERFORMANCE.md#bus-bandwidth
"""
match kind:
# For AllGather the size in the bandwidth calculation is the total/output size
# https://github.com/NVIDIA/nccl-tests/blob/c6afef0b6f76ffc55d4172d971be6cf5a08a73a4/doc/PERFORMANCE.md#allgather
case "all-gather":
return (size, (size - 1) / size)
case "all-reduce":
return (1, 2 * (size - 1) / size)
case "collective-broadcast":
return (1, 1)
case "collective-permute":
return (1, 1)
# For ReduceScatter the size in the bandwidth calculation is the total size
# https://github.com/NVIDIA/nccl-tests/blob/c6afef0b6f76ffc55d4172d971be6cf5a08a73a4/doc/PERFORMANCE.md#reducescatter
case "reduce-scatter":
return (size, (size - 1) / size)
case _:
assert False, f"Unknown collective kind {kind}"


@functools.lru_cache
def get_message_size(program_id: int, instruction: str) -> int:
def get_message_size(program_id: int, instruction: str) -> pd.Series:
"""
Given the name of a collective instruction (e.g. all-gather-start.N), calculate the
message size in bytes. See https://openxla.org/xla/operation_semantics#allgather,
Expand All @@ -59,42 +83,28 @@ def get_message_size(program_id: int, instruction: str) -> int:
else:
# replica_groups is something like {{0,1},{4,5},{2,3},{6,7}}, if there are 8
# devices that are doing pair-wise collectives
collective_sizes = tuple(
{len(group.replica_ids) for group in inst.replica_groups}
)
assert (
len(collective_sizes) == 1
collective_size = len(inst.replica_groups[0].replica_ids)
assert all(
len(group.replica_ids) == collective_size for group in inst.replica_groups
), f"Heterogeneous collective {inst.replica_groups} could not be interpreted"
collective_size = collective_sizes[0]
total_msg_size = 0
for operand_id in inst.operand_ids:
_, operand = module_proto.find_instruction_by_id(operand_id)
msg_size_bits = math.prod(
operand.shape.dimensions,
start=element_type_in_bits(operand.shape.element_type),
start=element_type_width(operand.shape.element_type),
)
if inst.opcode == "reduce-scatter":
# NCCL's convention is that the message size of a reduce-scatter is the size of output buffer:
# https://github.com/NVIDIA/nccl/blob/ab2b89c4c339bd7f816fbc114a4b05d386b66290/src/collectives.cc#L122
assert msg_size_bits % collective_size == 0
msg_size_bits //= collective_size
assert msg_size_bits % 8 == 0
total_msg_size += msg_size_bits // 8
msg_size_bits, rem = divmod(msg_size_bits, collective_size)
assert rem == 0
msg_size_bytes, rem = divmod(msg_size_bits, 8)
assert rem == 0
total_msg_size += msg_size_bytes

# Calculate the correction factor from algorithm bandwidth to bus bandwidth, see:
# https://github.com/NVIDIA/nccl-tests/blob/master/doc/PERFORMANCE.md#bus-bandwidth
collective = inst.opcode.removesuffix("-start")
bw_correction, bus_correction = {
# For AllGather the size in the bandwidth calculation is the total/output size
# https://github.com/NVIDIA/nccl-tests/blob/c6afef0b6f76ffc55d4172d971be6cf5a08a73a4/doc/PERFORMANCE.md#allgather
"all-gather": (collective_size, (collective_size - 1) / collective_size),
"all-reduce": (1, 2 * (collective_size - 1) / collective_size),
"collective-broadcast": (1, 1),
"collective-permute": (1, 1),
# For ReduceScatter the size in the bandwidth calculation is the total size
# https://github.com/NVIDIA/nccl-tests/blob/c6afef0b6f76ffc55d4172d971be6cf5a08a73a4/doc/PERFORMANCE.md#reducescatter
"reduce-scatter": (collective_size, (collective_size - 1) / collective_size),
}[collective]
bw_correction, bus_correction = _collective_correction(collective, collective_size)
return pd.Series(
[total_msg_size, collective, collective_size, bw_correction, bus_correction],
index=[
Expand Down

0 comments on commit e28ad9b

Please sign in to comment.