fix: fix ccl op signature mismatch when updating torch to 2.4 or above #474

Merged
merged 2 commits on Oct 26, 2024
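For context on the title: torch 2.4 traces functional collectives through the _c10d_functional op namespace, whose overloads take a process-group name, while the legacy c10d_functional overloads that the importer emits expect an explicit (tag, ranks, group_size) triple. The snippet below is a minimal sketch, not part of the PR, of recovering the legacy triple from a group name; it mirrors the resolve_group_name helper added in the patch and assumes a process group has already been initialized and registered under that name.

import torch
import torch.distributed.distributed_c10d as c10d


def legacy_collective_args(group_name: str):
    """Map a torch>=2.4 group_name argument back to the legacy
    (tag, ranks, group_size) triple expected by torch.ops.c10d_functional.*.

    Mirrors the resolve_group_name helper added by this patch; assumes a
    process group has already been initialized (e.g. via
    torch.distributed.init_process_group) and registered under group_name.
    """
    # Look up the live process group registered under this name
    # (same private API the patch uses).
    group = torch._C._distributed_c10d._resolve_process_group(group_name)
    # Legacy c10d_functional ops want explicit ranks and group size.
    ranks = c10d.get_process_group_ranks(group)
    return group_name, ranks, group.size()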
115 changes: 112 additions & 3 deletions frontends/torch-frontend/third_party/patches/fx_importer.patch
@@ -1,5 +1,5 @@
diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py
index 91d81de0..6f5e041f 100644
index 99c8d3cf..1fc2a5e1 100644
--- a/python/torch_mlir/extras/fx_importer.py
+++ b/python/torch_mlir/extras/fx_importer.py
@@ -54,6 +54,10 @@ from torch._subclasses import (
@@ -13,7 +13,116 @@ index 91d81de0..6f5e041f 100644
from torch.fx import (
Graph,
GraphModule,
@@ -2096,6 +2100,8 @@ def _make_vtensor_literal_op(
@@ -1614,6 +1618,61 @@ class GraphNodeImporter:
for i, value in enumerate(operation.results):
self.bind_node_value(node, value, i + bind_none)

+ def _import_torch_c10d_functional_op_overload(
+ self,
+ node: torch_fx.node,
+ schema,
+ loc: Location,
+ ):
+ import torch.distributed.distributed_c10d as c10d
+ def resolve_group_name(group_name: str) -> Tuple[str, List[int], int]:
+ group = torch._C._distributed_c10d._resolve_process_group(group_name)
+ group_rank = group.rank()
+ group_size = group.size()
+ global_group_ranks = c10d.get_process_group_ranks(group)
+ return group_name, global_group_ranks, group_size
+
+ operands = []
+ group_size, tag = None, None,
+ for i, parameter in enumerate(schema.arguments):
+ if parameter.name == "group_name":
+ if i < len(node.args):
+ group_name = node.args[i]
+ else:
+ assert parameter.name in node.kwargs
+ group_name = node.kwargs[parameter.name]
+ tmp_tag, global_global_ranks, tmp_group_size = resolve_group_name(group_name)
+ if group_size is None:
+ group_size = tmp_group_size
+ if tag is None:
+ tag = tmp_tag
+ global_global_ranks = torch_fx.immutable_collections.immutable_list(global_global_ranks)
+ operands.append(self._import_argument(loc, tag, str))
+ operands.append(self._import_argument(loc, global_global_ranks, torch.ListType.ofInts()))
+ operands.append(self._import_argument(loc, group_size, int))
+ elif parameter.name == "group_size":
+ group_size = node.args[i] if i < len(node.args) else node.kwargs["group_size"]
+ elif parameter.name == "tag":
+ tag = node.args[i] if i < len(node.args) else node.kwargs["tag"]
+ else:
+ if i < len(node.args):
+ operands.append(
+ self._import_argument(loc, node.args[i], parameter.type)
+ )
+ elif parameter.name in node.kwargs:
+ operands.append(
+ self._import_argument(
+ loc, node.kwargs[parameter.name], parameter.type
+ )
+ )
+ else:
+ operands.append(
+ self._import_default_value(
+ loc, parameter.default_value, parameter.type
+ )
+ )
+ return operands
+
def _import_torch_op_overload(
self,
loc: Location,
@@ -1655,24 +1714,30 @@ class GraphNodeImporter:
self._multi_result_nodes.add(node)

# Unroll operands from formal parameters, args and kwargs.
- operands = []
- for i, parameter in enumerate(schema.arguments):
- if i < len(node.args):
- operands.append(
- self._import_argument(loc, node.args[i], parameter.type)
- )
- elif parameter.name in node.kwargs:
- operands.append(
- self._import_argument(
- loc, node.kwargs[parameter.name], parameter.type
+ if "c10d_functional" in mlir_op_name:
+ # Since pytorch has two sets of collective operators defined in different OpNamespaces,
+ # we are enforcing a unified one.
+ mlir_op_name = mlir_op_name.replace("_c10d_functional", "c10d_functional")
+ operands = self._import_torch_c10d_functional_op_overload(node, schema, loc)
+ else:
+ operands = []
+ for i, parameter in enumerate(schema.arguments):
+ if i < len(node.args):
+ operands.append(
+ self._import_argument(loc, node.args[i], parameter.type)
)
- )
- else:
- operands.append(
- self._import_default_value(
- loc, parameter.default_value, parameter.type
+ elif parameter.name in node.kwargs:
+ operands.append(
+ self._import_argument(
+ loc, node.kwargs[parameter.name], parameter.type
+ )
+ )
+ else:
+ operands.append(
+ self._import_default_value(
+ loc, parameter.default_value, parameter.type
+ )
)
- )

operation = _emit_operation(
mlir_op_name, result_types=result_types, operands=operands, loc=loc
@@ -2057,6 +2122,8 @@ def _make_vtensor_literal_op(
) -> Operation:
mapping = py_attr_tracker.track(tensor)
if mapping.is_empty:
@@ -22,7 +131,7 @@ index 91d81de0..6f5e041f 100644
# check support for bfloat16
assert not (
tensor.dtype == torch.bfloat16 and ml_dtypes is None
@@ -2111,11 +2117,17 @@ def _make_vtensor_literal_op(
@@ -2072,11 +2139,17 @@ def _make_vtensor_literal_op(
# detach() which throws an error as we are operating in a FakeTensorMode, hence the simplest way to get this raw
# buffer is via the indirection: Tensor -> list -> numpy array. This allows us to create a vtensor literal as
# desired, but also limits which data types we can support in this function (see TORCH_DTYPE_TO_NPY_TYPE above)
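The hunk above also rewrites the op name so both namespaces funnel into one spelling before the MLIR op is emitted. As a rough standalone illustration of that rewrite (the op-name string in the example is hypothetical; the real importer's names may be formatted differently):

def unify_c10d_op_name(mlir_op_name: str) -> str:
    # Same string rewrite the patch applies: route torch>=2.4
    # _c10d_functional ops through the legacy c10d_functional namespace.
    if "c10d_functional" in mlir_op_name:
        mlir_op_name = mlir_op_name.replace("_c10d_functional", "c10d_functional")
    return mlir_op_name


# Hypothetical op-name string for illustration only.
print(unify_c10d_op_name("torch._c10d_functional.all_reduce.default"))
# -> torch.c10d_functional.all_reduce.default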
@@ -0,0 +1,139 @@
import torch
from torch.testing import FileCheck

import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
from torch.testing._internal.common_utils import run_tests

from utils import with_comms, DistributedTestBase

import torch_frontend
from torch_frontend import compile_dynamo_model


class AllReduceModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return funcol.all_reduce(x, "sum", [0, 1, 2, 3])


class AllGatherModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return funcol.all_gather_tensor(x, 0, [0, 1, 2, 3])


class ReduceScatterModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return funcol.reduce_scatter_tensor(x, "sum", 0, [0, 1, 2, 3])


class BroadcastModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x):
        return funcol.broadcast(x, 2, [0, 1, 2, 3])


class DistributedCollectiveTest(DistributedTestBase):
    @property
    def world_size(self):
        return 4

    @with_comms
    def test_reduce_scatter(self):
        module = ReduceScatterModule()
        inputs = [torch.tensor([1, 2, 3, 4], dtype=torch.float32)]
        prog = torch.export.export(module, tuple(inputs))
        if dist.get_rank() == 0:
            module = compile_dynamo_model(prog, "stablehlo")
            ir = module.operation.get_asm()
            FileCheck().check("@main").check("ccl.reduce_scatter").check(
                "axis = 0"
            ).check('reduction = "sum"').check("replica_groups = [[0, 1, 2, 3]]").check(
                "-> tensor<1xf32>"
            ).run(
                ir
            )

    @with_comms
    def test_all_reduce(self):
        module = AllReduceModule()
        inputs = [torch.tensor([1, 2, 3, 4], dtype=torch.float32)]
        prog = torch.export.export(module, tuple(inputs))
        if dist.get_rank() == 0:
            module = compile_dynamo_model(prog, "stablehlo")
            ir = module.operation.get_asm()
            FileCheck().check("@main").check("ccl.all_reduce").check(
                'reduction = "sum"'
            ).check("replica_groups = [[0, 1, 2, 3]]").check("-> tensor<4xf32>").run(ir)

    @with_comms
    def test_all_gather(self):
        module = AllGatherModule()
        inputs = [torch.tensor([1, 2, 3, 4], dtype=torch.float32)]
        prog = torch.export.export(module, tuple(inputs))
        if dist.get_rank() == 0:
            module = compile_dynamo_model(prog, "stablehlo")
            ir = module.operation.get_asm()
            FileCheck().check("@main").check("ccl.all_gather").check("axis = 0").check(
                "replica_groups = [[0, 1, 2, 3]]"
            ).check("-> tensor<16xf32>").run(ir)

    @with_comms
    def test_broadcast(self):
        module = BroadcastModule()
        inputs = [torch.tensor([1, 2, 3, 4], dtype=torch.float32)]
        prog = torch.export.export(module, tuple(inputs))
        if dist.get_rank() == 0:
            module = compile_dynamo_model(prog, "stablehlo")
            ir = module.operation.get_asm()
            FileCheck().check("@main").check("ccl.broadcast").check(
                "replica_groups = [[2, 0, 1, 3]]"
            ).check("-> tensor<4xf32>").run(ir)

    # TODO: add test for send/recv


class MLP(torch.nn.Module):
    def __init__(self, hidden_dim, world_size):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.world_size = world_size
        self.fc1 = torch.nn.Linear(self.hidden_dim, self.hidden_dim * 4)
        self.fc2 = torch.nn.Linear(self.hidden_dim * 4, self.hidden_dim)

    def forward(self, x):
        return funcol.all_reduce(
            self.fc2(self.fc1(x)), "sum", list(range(self.world_size))
        )


class DistributedCollectiveE2ETest(DistributedTestBase):
    @property
    def world_size(self):
        return 4

    @with_comms
    def test_mlp_e2e(self):
        module = MLP(hidden_dim=4, world_size=self.world_size)
        x = torch.rand(3, 4)
        prog = torch.export.export(module, (x,))

        module = compile_dynamo_model(prog, "stablehlo")

        if dist.get_rank() == 0:
            ir = module.operation.get_asm()
            print(ir)


if __name__ == "__main__":
    run_tests()
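The tests above use torch.testing.FileCheck to assert that directives appear, in order, in the printed StableHLO. A minimal standalone sketch of that pattern; the IR string here is made up and only stands in for the output of module.operation.get_asm():

from torch.testing import FileCheck

# Made-up IR text standing in for the printed StableHLO module.
ir = """
func.func @main(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  %0 = ccl.all_reduce %arg0 {reduction = "sum", replica_groups = [[0, 1, 2, 3]]} : tensor<4xf32>
  return %0 : tensor<4xf32>
}
"""

# Each .check() must match after the previous one; .run() raises if any is missing.
FileCheck().check("@main").check("ccl.all_reduce").check(
    'reduction = "sum"'
).check("replica_groups = [[0, 1, 2, 3]]").run(ir)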
@@ -0,0 +1,121 @@
import datetime
import sys
from typing import Any, Callable, Dict, Tuple, TypeVar, cast
from functools import wraps

import torch
import torch.distributed as dist
from torch.testing._internal.common_distributed import TEST_SKIPS, MultiProcessTestCase, skip_if_lt_x_gpu, TestSkip

# add new skipped test exit code
TEST_SKIPS["torch-version-2.2"] = TestSkip(90, "Need torch version bigger than 2.2")

TestFunc = Callable[[object], object]
T = TypeVar("T")
DEVICE_TYPE = "cuda" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else "cpu"
PG_BACKEND = "nccl" if DEVICE_TYPE == "cuda" else "gloo"

NUM_DEVICES = 4

# We use this as a proxy for "multiple GPUs exist"
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    # when we actually have multiple GPUs, relax the requirement to smaller counts.
    NUM_DEVICES = min(NUM_DEVICES, torch.cuda.device_count())


class DistributedTestBase(MultiProcessTestCase):
    @property
    def world_size(self) -> int:
        return NUM_DEVICES

    @property
    def backend(self) -> str:
        return PG_BACKEND

    def init_pg(self) -> None:
        if "nccl" in self.backend and torch.cuda.device_count() < self.world_size:
            sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)

        if self.backend not in ["nccl", "gloo", "mpi", "cpu:gloo,cuda:nccl", "meta"]:
            raise RuntimeError(f"Backend {self.backend} not supported!")

        dist.init_process_group(
            backend=self.backend,
            world_size=self.world_size,
            rank=self.rank,  # pyre-ignore[16]
            init_method=f"file://{self.file_name}",  # pyre-ignore[16]
            timeout=datetime.timedelta(seconds=1200),
        )

        # set device for nccl pg for collectives
        if "nccl" in self.backend:
            torch.cuda.set_device(self.rank)

    def destroy_pg(self) -> None:
        # Wait for all ranks to reach here before starting shutdown.
        # FIXME dist.barrier deadlocks with multiple threads and NCCL: https://github.com/pytorch/pytorch/issues/95895
        # dist.all_reduce(torch.zeros((1,), device="cuda" if torch.cuda.is_available() else "cpu"))
        # FIXME can't use the above all_reduce as it causes hangs on bionic and focal. It hangs:
        # test_dtensor.py -- DTensorMeshTest.test_dtensor_device_mesh_device_conversion
        dist.barrier()
        dist.destroy_process_group()

    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()


# wrapper to initialize comms (processgroup)
def with_comms(func: TestFunc) -> TestFunc:
    assert func is not None

    @wraps(func)  # pyre-ignore[6]
    def wrapper(self, *args: Tuple[object], **kwargs: Dict[str, Any]) -> None:  # type: ignore[misc]
        # if backend not specified, and cuda available, then use nccl, else gloo
        if torch.cuda.is_available() and torch.cuda.device_count() >= self.world_size:
            self.device_type = "cuda"
        else:
            self.device_type = "cpu"

        self.init_pg()
        func(self, *args, **kwargs)  # type: ignore[misc]
        self.destroy_pg()

    return wrapper


def skip_unless_torch_gpu(method: T) -> T:
    """
    Test decorator which skips the test unless there's a GPU available to torch.

    >>> # xdoctest: +SKIP
    >>> @skip_unless_torch_gpu
    >>> def test_some_method(self) -> None:
    >>> ...
    """
    # The builtin @skip_if_no_gpu relies on os.environ['WORLD_SIZE'] being set.
    return cast(T, skip_if_lt_x_gpu(NUM_DEVICES)(method))


def skip_unless_torch_version_bigger_than(torch_version: str):
    """
    Test decorator which skips the test unless current torch version is
    bigger than the given number.

    >>> # xdoctest: +SKIP
    >>> @skip_unless_torch_version_bigger_than(torch_version="2.2")
    >>> def test_some_method(self) -> None:
    >>> ...
    """

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            current_torch_version = torch.__version__
            if current_torch_version >= torch_version:
                return func(*args, **kwargs)
            sys.exit(TEST_SKIPS[f"torch-version-{torch_version}"].exit_code)

        return wrapper

    return decorator