-
Notifications
You must be signed in to change notification settings - Fork 24
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for varying vector shapes #247
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,7 +12,7 @@ | |
UnitAttr, | ||
Value, | ||
) | ||
from typing import Optional, Callable, Any, List, Tuple | ||
from typing import Optional, Callable, Any, List, Tuple, Sequence | ||
from .._support.tracing import CapturedTrace | ||
from .._support.indexing import IndexExpr, IndexingContext, IndexSymbol, IndexSequence | ||
from ..lang.global_symbols import * | ||
|
@@ -25,6 +25,7 @@ | |
Reduction, | ||
GetResult, | ||
IterArg, | ||
Reshape, | ||
) | ||
from .constraints import ( | ||
Constraint, | ||
|
@@ -192,6 +193,20 @@ def simplify_index(index: IndexExpr) -> IndexExpr: | |
return subs_idxc(index.subs(mapping)) | ||
|
||
|
||
def is_reshape_needed( | ||
node: CustomOp, | ||
node_vector_shapes: dict[IndexSymbol, int], | ||
vector_shapes: dict[IndexSymbol, int], | ||
) -> bool: | ||
for dim in node.type.symbolic_shape: | ||
if dim not in vector_shapes: | ||
# Ignore nodes that are not used in both mmas. | ||
return False | ||
if node_vector_shapes[dim] != vector_shapes[dim]: | ||
return True | ||
return False | ||
|
||
|
||
def get_mma_dimensional_mapping( | ||
trace: CapturedTrace, | ||
hardware_constraint: HardwareConstraint, | ||
|
@@ -243,6 +258,32 @@ def is_mma(node): | |
reduction.anchor = custom | ||
|
||
mma_slices = {get_custom(x): capture_mma_slices(get_custom(x)) for x in mma_nodes} | ||
|
||
# Determine if any reshapes are required. Reshapes are added for | ||
# chained matmuls when the vector shapes of the operands in one matmul | ||
# differ from those in another matmul. | ||
for src in mma_nodes: | ||
custom_src = get_custom(src) | ||
for dst in mma_nodes: | ||
if src == dst: | ||
continue | ||
custom_dst = get_custom(dst) | ||
lhs_slice = capture_backward_slice(custom_dst.lhs) | ||
rhs_slice = capture_backward_slice(custom_dst.rhs) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we remove the double nested loop, and instead find the mma op inside lhs_slice/rhs_slice? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes we can, but then we would have to iterate through all the operators in the slice to see if any of them has an mma as an argument. That would end up with a triple-nested-loop. Unless you are thinking of something else? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. okay I think I have a cleaner version now. let me know what you think. |
||
if src in lhs_slice or src in rhs_slice: | ||
with custom_dst.graph.inserting_before(dst): | ||
for i, arg in custom_dst.node_args.items(): | ||
if is_reshape_needed( | ||
arg, custom_dst.vector_shapes, custom_src.vector_shapes | ||
): | ||
reshape = Reshape( | ||
arg.fx_node, custom_src.vector_shapes | ||
).add_to_graph(custom.graph) | ||
custom_reshape = get_custom(reshape) | ||
custom_reshape.vector_shapes = custom.vector_shapes | ||
custom_reshape.anchor = custom | ||
custom.update_arg(i, reshape) | ||
|
||
return mapping, mma_slices | ||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need to do anything special to ensure this ordering of expansion is correct?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think so, but do you have an example in mind?