Sharded Llama missing cache update in exported MLIR #271
Here is a PR that helps address this. It takes the approach of inserting an in-place device-placement Torch FX op that does not get materialized as an op, but instead sets the affinity for the corresponding function argument. This seems a bit brittle.

There is one other approach, where we insert attributes into the Torch FX graph after generation:

```python
import torch
from torch.fx import Graph


def create_fx_function_with_custom_attributes():
    graph = Graph()

    # Create an input node and attach a custom attribute to it.
    input_node = graph.placeholder('x', type_expr=torch.Tensor)
    input_node.custom_attr = "input_custom_value"

    # Create a parameter node with its own custom attribute.
    param_node = graph.placeholder('weight', type_expr=torch.Tensor)
    param_node.custom_attr = "weight_custom_value"

    # Create an operation node and mark it as the graph output.
    output_node = graph.call_function(torch.matmul, args=(input_node, param_node))
    graph.output(output_node)

    # Wrap the graph in a GraphModule.
    return torch.fx.GraphModule(torch.nn.Module(), graph)


# Create the FX function.
fx_module = create_fx_function_with_custom_attributes()

# Print the generated code.
print(fx_module.code)

# Access the custom attributes.
for node in fx_module.graph.nodes:
    if hasattr(node, 'custom_attr'):
        print(f"Node: {node.name}, Custom Attribute: {node.custom_attr}")
```

Then we convert these attributes to function argument attributes in MLIR after the main conversion has occurred (a rough sketch of that step follows below).

@stellaraccident, what other approach is there?
I think that there are some disconnected pieces needed to make this work. Can we confirm that we really do need these function annotations? This might need to be done in our call to import, passing device placements for some arguments, rather than through the export goo.
If we don't set the
The sharded variant of the exported Llama from this test is missing the IR for the update of the paged cache.
Note that the equivalent unsharded model does not exhibit this problem.
My hypothesis is that the update got erroneously optimized out by dead code elimination, because we are not generating the cache update code to be truly in-place. It is probably a problem with the interaction with `flow.tensor.transfer`, which actually introduces a new tensor value. When not exporting, we just return the same tensor, so the in-place semantics work fine. Here are both the sharded and unsharded exported MLIR programs: programs-mlir.zip.
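As a rough Python analogy of the hypothesis (not the actual sharktank code): an in-place write mutates the caller's cache buffer and therefore has a visible side effect, whereas a write that happens after a value-producing copy (the role `flow.tensor.transfer` plays here) only touches the new tensor, so the whole update is trivially dead if that new value is never consumed.

```python
import torch


def update_cache_in_place(cache, new_vals, pos):
    # The caller's buffer is mutated directly, so the update has a visible
    # side effect and cannot be dropped as dead code.
    cache.index_copy_(0, pos, new_vals)
    return cache


def update_cache_via_new_value(cache, new_vals, pos):
    # Stand-in for a transfer that introduces a new tensor value: the write
    # lands in the copy, not in the caller's buffer. If the returned tensor
    # is never consumed, the entire update looks dead and can be eliminated.
    transferred = cache.clone()
    transferred.index_copy_(0, pos, new_vals)
    return transferred


cache_a = torch.zeros(4, 8)
cache_b = torch.zeros(4, 8)
new_vals = torch.ones(1, 8)
pos = torch.tensor([2])

update_cache_in_place(cache_a, new_vals, pos)
update_cache_via_new_value(cache_b, new_vals, pos)
print(cache_a[2].sum())  # 8.0 -- the caller's buffer was updated
print(cache_b[2].sum())  # 0.0 -- the update only touched the discarded copy
```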