[torch-frontend] add stablehlo IRs for Mixtral model. #254

Open · wants to merge 4 commits into main
@@ -3,6 +3,7 @@

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from functorch.compile import aot_module, default_partition
from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode
from torch._decomp import (
@@ -18,53 +19,6 @@
)
from transformers.models.mixtral.configuration_mixtral import MixtralConfig


def fake_export_whole_mixtral():
model_conf = MixtralConfig()
with FakeTensorMode():
mixtral = MixtralModel(model_conf)
# step 1: fake init
print(mixtral)

"""
MixtralModel(
(embed_tokens): Embedding(32000, 4096)
(layers): ModuleList(
(0-31): 32 x MixtralDecoderLayer(
(self_attn): MixtralSdpaAttention(
(q_proj): Linear(in_features=4096, out_features=4096, bias=False)
(k_proj): Linear(in_features=4096, out_features=1024, bias=False)
(v_proj): Linear(in_features=4096, out_features=1024, bias=False)
(o_proj): Linear(in_features=4096, out_features=4096, bias=False)
(rotary_emb): MixtralRotaryEmbedding()
)
(block_sparse_moe): MixtralSparseMoeBlock(
(gate): Linear(in_features=4096, out_features=8, bias=False)
(experts): ModuleList(
(0-7): 8 x MixtralBlockSparseTop2MLP(
(w1): Linear(in_features=4096, out_features=14336, bias=False)
(w2): Linear(in_features=14336, out_features=4096, bias=False)
(w3): Linear(in_features=4096, out_features=14336, bias=False)
(act_fn): SiLU()
)
)
)
(input_layernorm): MixtralRMSNorm()
(post_attention_layernorm): MixtralRMSNorm()
)
)
(norm): MixtralRMSNorm()
)
"""

# step 2: torch.export
bsz = 5
seq_len = 7
vocab_size = 32000
token_ids = torch.randint(0, vocab_size, (bsz, seq_len))
exported_mixtral = torch.export.export(mixtral, (token_ids,))
print(exported_mixtral)

def export_mixtral_decoding_layer():
torch._dynamo.config.capture_dynamic_output_shape_ops = True
model_conf = MixtralConfig(hidden_size=32)
@@ -79,5 +33,48 @@ def export_mixtral_decoding_layer():
module = torch_frontend.compile_dynamo_model(mixtral_decoder_layer, output_type="stablehlo")
print(module.operation.get_asm())


def export_fake_whole_mixtral_model():
torch._dynamo.config.capture_dynamic_output_shape_ops = True
mixtral_config = MixtralConfig(attn_implementation="eager")
bsz = 3
seqlen = 1
past_kv_seqlen = 7
with FakeTensorMode(shape_env=ShapeEnv()):
mixtral_model = MixtralModel(mixtral_config)
input_ids = torch.randint(0, mixtral_config.vocab_size, (bsz, seqlen))

past_key_values = []
for i in range(mixtral_config.num_hidden_layers):
past_key_values.append(
(
torch.rand(
bsz,
mixtral_config.num_key_value_heads,
past_kv_seqlen,
mixtral_config.hidden_size
// mixtral_config.num_attention_heads,
),
torch.rand(
bsz,
mixtral_config.num_key_value_heads,
past_kv_seqlen,
mixtral_config.hidden_size
// mixtral_config.num_attention_heads,
),
)
)
exported_mixtral = torch.export.export(
mixtral_model,
(input_ids, None, None, past_key_values),
)
import torch_frontend
module = torch_frontend.compile_dynamo_model(
exported_mixtral,
output_type="stablehlo",
)
print(module.operation.get_asm())

if __name__ == "__main__":
export_fake_whole_mixtral_model()
export_mixtral_decoding_layer()
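
Note: the mixtral.stablehlo.mlir artifact below is presumably produced from this script. A minimal sketch of how one could write it out (dump_stablehlo is a hypothetical helper; it only reuses the module.operation.get_asm() call already shown above):

    def dump_stablehlo(module, path="mixtral.stablehlo.mlir"):
        # Serialize the StableHLO module produced by compile_dynamo_model
        # to a .mlir file instead of printing it to stdout.
        with open(path, "w") as f:
            f.write(module.operation.get_asm())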
32,271 changes: 32,271 additions & 0 deletions frontends/torch-frontend/examples/inference/mixtral/mixtral.stablehlo.mlir

Large diffs are not rendered by default.

158 changes: 158 additions & 0 deletions frontends/torch-frontend/third_party/patches/fx_importer.patch
@@ -0,0 +1,158 @@
diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py
index 381f8f9a..a5e2d32b 100644
--- a/python/torch_mlir/extras/fx_importer.py
+++ b/python/torch_mlir/extras/fx_importer.py
@@ -52,6 +52,10 @@ from torch._subclasses import (
FakeTensor as TorchFakeTensor,
)

+from torch.distributed._functional_collectives import (
+ AsyncCollectiveTensor as TorchAsyncCollectiveTensor
+)
+
from torch.fx import (
Graph,
GraphModule,
@@ -924,6 +928,19 @@ class ContextCache:
tensor_meta = node.meta.get("tensor_meta")
val = node.meta.get("val")
sparsity = node.meta.get("sparsity", None)
+            # Some nodes return a list, like torch.ops.aten.unbind.int
+ if isinstance(tensor_meta, List) or isinstance(val, List):
+ if tensor_meta is not None and all(x is not None for x in tensor_meta):
+ # Assume that all results in the list are tensors.
+                    # TODO: remove this assumption
+ return IrType.parse("!torch.list<vtensor>", context=self._c)
+                elif val is not None and all(x is not None for x in val):
+                    return IrType.parse("!torch.list<vtensor>", context=self._c)
+                else:
+                    raise NotImplementedError(
+                        f"FIXME: Unsupported placeholder node (this often indicates that a necessary "
+                        f"fx preprocessing pass was not run): {node.meta}"
+ )
except KeyError as e:
raise RuntimeError(
f"FIXME: Illegal access to torch.fx.Node.meta: {e} ({node.meta.keys()} : {node.meta})"
@@ -1035,6 +1052,7 @@ class GraphNodeImporter:
"_on_node_produced",
"_v",
"_multi_result_nodes",
+ "_list_return_nodes",
"fx_importer",
]

@@ -1058,6 +1076,9 @@ class GraphNodeImporter:
# They will have their getitem calls short-circuited.
self._multi_result_nodes: Set[torch_fx.Node] = set()

+ # Stores the node that returns a list, like aten.unbind.int
+ self._list_return_nodes: Set[torch_fx.Node] = set()
+
def bind_node_value(
self,
node: Node,
@@ -1213,6 +1234,23 @@ class GraphNodeImporter:
f"notify developers if this case happens "
f"(at {loc})."
)
+ elif getitem_ref in self._list_return_nodes:
+ fx_list_return_value = self._v[(getitem_ref, 0)]
+ operands = [
+ fx_list_return_value,
+ self._import_default_value(loc, getitem_index, torch.IntType)
+ ]
+
+ # We trust the tensor type in FX graph, even if it's a getitem
+ # from a value of MLIR ListType.
+ operation = Operation.create(
+ "torch.aten.__getitem__.t",
+ results=(self._cc.node_val_to_type(node),),
+                    operands=operands,
+ loc=loc
+ )
+ for i, value in enumerate(operation.results):
+ self._v[(node, i)] = value
else:
raise NotImplementedError(
f"General getitem access to non-multi-result ops"
@@ -1676,6 +1714,10 @@ class GraphNodeImporter:
# Unary return directly maps a single meta["val"] and cannot be subscripted.
# if "tensor_meta" is None, this will throw unsupported placeholder node error
result_types = [self._cc.node_val_to_type(node)]
+
+ # separately handle ops returning list.
+ if str(result_types[0]).startswith("!torch.list"):
+ self._list_return_nodes.add(node)
elif return_count == 0:
# Some torch ops do have 0 returns, and these are supported with ZeroResults
# op trait. Python bindings for IR creation allow us to pass empty result_types
@@ -1717,6 +1759,8 @@ def _make_vtensor_literal_op(
) -> Operation:
mapping = py_attr_tracker.track(tensor)
if mapping.is_empty:
+ # unwrap from TorchAsyncCollectiveTensor
+ tensor = tensor.elem if isinstance(tensor, TorchAsyncCollectiveTensor) else tensor
# check support for bfloat16
assert not (
tensor.dtype == torch.bfloat16 and ml_dtypes is None
@@ -1732,29 +1776,42 @@ def _make_vtensor_literal_op(
# detach() which throws an error as we are operating in a FakeTensorMode, hence the simplest way to get this raw
# buffer is via the indirection: Tensor -> list -> numpy array. This allows us to create a vtensor literal as
# desired, but also limits which data types we can support in this function (see TORCH_DTYPE_TO_NPY_TYPE above)
- np_tensor = np.array(tensor.tolist()).astype(npy_dtype)
- # One element constants are more optimizable as splat DenseElementsAttr. DenseResourceElementsAttr does not
- # support splats, so don't use it for that case. In addition, at the time of writing, it has bugs with handling
- # 0d tensors.
- if np_tensor.size == 1:
+
+ # NOTE: if we torch.export a torch.nn.Module under fake mode, the parameters in the fx.GraphModule will be FakeTensor.
+ # So we specifically handle FakeTensor here by creating a splat DenseElementsAttr with value 0.
+ if isinstance(tensor, TorchFakeTensor):
+ array = np.array([0]).astype(npy_dtype)
try:
- dtype = tensor.dtype
- element_type = TORCH_DTYPE_TO_MLIR_TYPE[dtype]()
+ element_type = TORCH_DTYPE_TO_MLIR_TYPE[tensor.dtype]()
except KeyError:
- raise TypeError(f"Could not map Torch dtype {dtype} to an MLIR type")
+ raise TypeError(f"Could not map Torch dtype {tensor.dtype} to an MLIR type")
elements_attr = DenseElementsAttr.get(
- type=element_type, array=np_tensor, shape=np_tensor.shape
+ array=array, type=element_type, shape=list(tensor.shape)
)
else:
- bytes_view = np_tensor.view(npy_dtype)
- tensor_type = create_mlir_tensor_type(tensor)
- shape_desc = "_".join([str(d) for d in tensor.shape])
- blob_name = f"torch_tensor_{shape_desc}_{str(tensor.dtype)}"
- elements_attr = DenseResourceElementsAttr.get_from_buffer(
- bytes_view,
- blob_name,
- tensor_type,
- )
+ np_tensor = np.array(tensor.tolist()).astype(npy_dtype)
+ # One element constants are more optimizable as splat DenseElementsAttr. DenseResourceElementsAttr does not
+ # support splats, so don't use it for that case. In addition, at the time of writing, it has bugs with handling
+ # 0d tensors.
+ if np_tensor.size == 1:
+ try:
+ dtype = tensor.dtype
+ element_type = TORCH_DTYPE_TO_MLIR_TYPE[dtype]()
+ except KeyError:
+ raise TypeError(f"Could not map Torch dtype {dtype} to an MLIR type")
+ elements_attr = DenseElementsAttr.get(
+ type=element_type, array=np_tensor, shape=np_tensor.shape
+ )
+ else:
+ bytes_view = np_tensor.view(npy_dtype)
+ tensor_type = create_mlir_tensor_type(tensor)
+ shape_desc = "_".join([str(d) for d in tensor.shape])
+ blob_name = f"torch_tensor_{shape_desc}_{str(tensor.dtype)}"
+ elements_attr = DenseResourceElementsAttr.get_from_buffer(
+ bytes_view,
+ blob_name,
+ tensor_type,
+ )
mapping.value = elements_attr
else:
elements_attr = mapping.value
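
For context, the list-return path added above is exercised by ops such as torch.ops.aten.unbind.int, whose FX nodes carry a list in meta["val"] and are consumed through getitem nodes. A minimal sketch of a toy module that would hit this path (UnbindModel is hypothetical, not part of this PR):

    import torch

    class UnbindModel(torch.nn.Module):
        def forward(self, x):
            # aten.unbind.int returns a list of tensors; the getitem
            # accesses below are what the patched importer lowers to
            # torch.aten.__getitem__.t on a !torch.list<vtensor> value.
            parts = torch.unbind(x, dim=0)
            return parts[0] + parts[1]

    exported = torch.export.export(UnbindModel(), (torch.rand(2, 4),))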
77 changes: 0 additions & 77 deletions frontends/torch-frontend/third_party/patches/fx_list_return.patch

This file was deleted.
