Commit

update
qingyunqu committed Oct 25, 2024
1 parent 06131ce commit fd8724f
Showing 3 changed files with 47 additions and 46 deletions.
2 changes: 1 addition & 1 deletion frontends/torch-frontend/scripts/build_and_test.sh
@@ -51,7 +51,7 @@ cmake --build ./build --target all
if [[ $TORCH_FRONTEND_TEST == "ON" ]]; then
python3 -m pip install -r test-requirements.txt
install_mhlo_tools
-PYTHONPATH=build/python_packages/:build/torch_mlir_build/python_packages/torch_mlir TORCH_DISABLE_NATIVE_FUNCOL=1 python3 -m pytest torch-frontend/python/test
+PYTHONPATH=build/python_packages/:build/torch_mlir_build/python_packages/torch_mlir TORCH_DISABLE_NATIVE_FUNCOL=1 python3 -m pytest -m "not attention_rewriter" torch-frontend/python/test
fi

popd
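
The new -m "not attention_rewriter" flag deselects tests tagged with the attention_rewriter marker. As a minimal sketch of how such a marker is typically wired up (the file and test names below are hypothetical; only the marker name comes from this commit):

# conftest.py (hypothetical) -- register the marker so "pytest -m" can filter on it
def pytest_configure(config):
    config.addinivalue_line(
        "markers",
        "attention_rewriter: tests that exercise the attention rewrite patterns",
    )

# test_attn_rewriter.py (hypothetical)
import pytest

@pytest.mark.attention_rewriter
def test_llama_attention_is_rewritten():
    ...  # build a traced graph and check that the byteir flash attention op is present

With the marker registered, the command above still runs the rest of the suite while skipping those cases.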
@@ -179,6 +179,51 @@ def AttnReplacement5(q, k, v, attn_mask, inv_scale):
)


# LLaMA aten attention op pattern
def LLaMAAttnPattern(query, key, value, attn_mask, min_val, inv_scale, batch, num_head, fused_batch, seq_len, head_dim):
transpose_3 = torch.ops.aten.transpose.int(key, 2, 3)
expand_2 = torch.ops.aten.expand.default(query, [batch, num_head, seq_len, head_dim])
clone = torch.ops.aten.clone.default(expand_2, memory_format = torch.contiguous_format)
_unsafe_view_3 = torch.ops.aten._unsafe_view.default(clone, [fused_batch, seq_len, head_dim])
expand_3 = torch.ops.aten.expand.default(transpose_3, [batch, num_head, head_dim, seq_len])
clone_1 = torch.ops.aten.clone.default(expand_3, memory_format = torch.contiguous_format)
_unsafe_view_4 = torch.ops.aten._unsafe_view.default(clone_1, [fused_batch, head_dim, seq_len])
bmm = torch.ops.aten.bmm.default(_unsafe_view_3, _unsafe_view_4)
_unsafe_view_5 = torch.ops.aten._unsafe_view.default(bmm, [batch, num_head, seq_len, seq_len])
div = torch.ops.aten.div.Tensor(_unsafe_view_5, inv_scale)
add_5 = torch.ops.aten.add.Tensor(div, attn_mask)
maximum = torch.ops.aten.maximum.default(add_5, min_val)
_softmax = torch.ops.aten._softmax.default(maximum, -1, False)
_to_copy_10 = torch.ops.aten._to_copy.default(_softmax, dtype = torch.float16)
expand_4 = torch.ops.aten.expand.default(_to_copy_10, [batch, num_head, seq_len, seq_len])
view_8 = torch.ops.aten.view.default(expand_4, [fused_batch, seq_len, seq_len]); expand_4 = None
expand_5 = torch.ops.aten.expand.default(value, [batch, num_head, seq_len, head_dim])
clone_2 = torch.ops.aten.clone.default(expand_5, memory_format = torch.contiguous_format)
_unsafe_view_6 = torch.ops.aten._unsafe_view.default(clone_2, [fused_batch, seq_len, head_dim])
bmm_1 = torch.ops.aten.bmm.default(view_8, _unsafe_view_6)
_unsafe_view_5 = torch.ops.aten._unsafe_view.default(bmm_1, [batch, num_head, seq_len, head_dim])
return _softmax, _unsafe_view_5


def LLaMAAttnReplacement(query, key, value, attn_mask, min_val, inv_scale, batch, num_head, fused_batch, seq_len, head_dim):
    # q, k, v need to be transposed for flash attn v2
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
out, q_pad, k_pad, v_pad, out_pad, softmax_lse, S_dmask, rng_state = torch.ops.byteir.flash_attn_fwd(
query,
key,
value,
0.0,
1.0/inv_scale,
True,
False
)
# output also needs to be transposed
out = out.transpose(1, 2)
return out, out


def canonicalize_graph_before_replacement(gm):
for n in gm.graph.nodes:
if n.op == "call_module":
@@ -243,4 +288,5 @@ def fx_replace_attn_pattern(gm: torch.fx.GraphModule):
torch.fx.replace_pattern(gm, AttnPattern3, AttnReplacement3)
torch.fx.replace_pattern(gm, AttnPattern4, AttnReplacement4)
torch.fx.replace_pattern(gm, AttnPattern5, AttnReplacement5)
torch.fx.replace_pattern(gm, LLaMAAttnPattern, LLaMAAttnReplacement)
return gm
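
For context, each pattern/replacement pair registered here is handed to torch.fx.replace_pattern, which symbolically traces the pattern function, searches the GraphModule for matching subgraphs, and splices in the traced replacement. A minimal sketch of that mechanism on a toy graph (the module and functions below are hypothetical and deliberately simpler than the LLaMA pattern above):

import torch
import torch.fx as fx

class Toy(torch.nn.Module):
    def forward(self, x, y):
        return torch.relu(torch.add(x, y))

def pattern(x, y):
    return torch.relu(torch.add(x, y))

def replacement(x, y):
    # stand-in "fused" op; the commit above splices in torch.ops.byteir.flash_attn_fwd instead
    return torch.nn.functional.relu(x + y)

gm = fx.symbolic_trace(Toy())
fx.replace_pattern(gm, pattern, replacement)  # rewrites every matched subgraph in place
print(gm.graph)

Note that LLaMAAttnPattern divides the attention scores by inv_scale, so LLaMAAttnReplacement passes 1.0/inv_scale to flash_attn_fwd, presumably as the softmax scale, so that the effective scaling matches the matched subgraph.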
@@ -124,50 +124,6 @@ def unsafe_index_put_pattern(self, indices, values, accumulate):
def unsafe_index_put_replacement(self, indices, values, accumulate):
return torch.ops.aten.index_put_.hacked_twin(self, indices, values, accumulate)

# LLaMA aten attention op pattern
def LLaMAAttnPattern(query, key, value, attn_mask, min_val, inv_scale, batch, num_head, fused_batch, seq_len, head_dim):
transpose_3 = torch.ops.aten.transpose.int(key, 2, 3)
expand_2 = torch.ops.aten.expand.default(query, [batch, num_head, seq_len, head_dim])
clone = torch.ops.aten.clone.default(expand_2, memory_format = torch.contiguous_format)
_unsafe_view_3 = torch.ops.aten._unsafe_view.default(clone, [fused_batch, seq_len, head_dim])
expand_3 = torch.ops.aten.expand.default(transpose_3, [batch, num_head, head_dim, seq_len])
clone_1 = torch.ops.aten.clone.default(expand_3, memory_format = torch.contiguous_format)
_unsafe_view_4 = torch.ops.aten._unsafe_view.default(clone_1, [fused_batch, head_dim, seq_len])
bmm = torch.ops.aten.bmm.default(_unsafe_view_3, _unsafe_view_4)
_unsafe_view_5 = torch.ops.aten._unsafe_view.default(bmm, [batch, num_head, seq_len, seq_len])
div = torch.ops.aten.div.Tensor(_unsafe_view_5, inv_scale)
add_5 = torch.ops.aten.add.Tensor(div, attn_mask)
maximum = torch.ops.aten.maximum.default(add_5, min_val)
_softmax = torch.ops.aten._softmax.default(maximum, -1, False)
_to_copy_10 = torch.ops.aten._to_copy.default(_softmax, dtype = torch.float16)
expand_4 = torch.ops.aten.expand.default(_to_copy_10, [batch, num_head, seq_len, seq_len])
view_8 = torch.ops.aten.view.default(expand_4, [fused_batch, seq_len, seq_len]); expand_4 = None
expand_5 = torch.ops.aten.expand.default(value, [batch, num_head, seq_len, head_dim])
clone_2 = torch.ops.aten.clone.default(expand_5, memory_format = torch.contiguous_format)
_unsafe_view_6 = torch.ops.aten._unsafe_view.default(clone_2, [fused_batch, seq_len, head_dim])
bmm_1 = torch.ops.aten.bmm.default(view_8, _unsafe_view_6)
_unsafe_view_5 = torch.ops.aten._unsafe_view.default(bmm_1, [batch, num_head, seq_len, head_dim])
return _softmax, _unsafe_view_5


def LLaMAAttnReplacement(query, key, value, attn_mask, min_val, inv_scale, batch, num_head, fused_batch, seq_len, head_dim):
    # q, k, v need to be transposed for flash attn v2
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
out, q_pad, k_pad, v_pad, out_pad, softmax_lse, S_dmask, rng_state = torch.ops.byteir.flash_attn_fwd(
query,
key,
value,
0.0,
1.0/inv_scale,
True,
False
)
# output also needs to be transposed
out = out.transpose(1, 2)
return out, out


def get_none_indices(fx_g: torch.fx.GraphModule) -> List[int]:
none_indices = []
@@ -206,7 +162,6 @@ def preprocess_fx_graph(fx_graph: torch.fx.GraphModule):

torch.fx.replace_pattern(fx_graph, squeeze_dims_pattern, squeeze_dims_replacement)
torch.fx.replace_pattern(fx_graph, unsafe_index_put_pattern, unsafe_index_put_replacement)
torch.fx.replace_pattern(fx_graph, LLaMAAttnPattern, LLaMAAttnReplacement)
was_unwrapped = _unwrap_single_tuple_return(fx_graph)
was_list_replaced = _list_return_to_tuple_return(fx_graph)
removed_none_indexes = _remove_nones(fx_graph)
