[torch-frontend] support lowering triton.xxx to stablehlo.custom_call
qingyunqu committed Oct 25, 2024
1 parent 07f9744 · commit 1c238ae
Showing 3 changed files with 55 additions and 3 deletions.
@@ -1043,6 +1043,14 @@ class ConvertFlashAttnBwdOp : public OpConversionPattern<OperatorOp> {
  }
};

// torch.operator "byteir.flash_attn_kvcache"
// operands: q, k, v, kcache, vcache, seqlen_k, softmax_scale, causal
// results: out, softmax_lse
//
// CustomCall:
// operands: q, k, v, kcache, vcache, seqlen_k
// Attributes: softmax_scale, causal
// results: out, softmax_lse
class ConvertFlashAttnKVCacheOp : public OpConversionPattern<OperatorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
@@ -1103,6 +1111,40 @@ class ConvertFlashAttnKVCacheOp : public OpConversionPattern<OperatorOp> {
};
} // namespace

namespace {
class ConvertTritonOp : public OpConversionPattern<OperatorOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(OperatorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto opName = adaptor.getName();
    if (!opName.starts_with("triton."))
      return rewriter.notifyMatchFailure(op, "op name not match");

    auto operands = adaptor.getOperands();
    SmallVector<Type> resultTypes;
    if (failed(getTypeConverter()->convertTypes(op.getResultTypes(),
                                                resultTypes))) {
      return op.emitError("could not convert output types");
    }

    std::vector<NamedAttribute> byteir_attrs;
    auto attrs = getDefaultAttrs(rewriter);
    attrs.emplace_back(rewriter.getStringAttr("call_target_name"),
                       rewriter.getStringAttr(opName));
    attrs.emplace_back(rewriter.getStringAttr(getCustomCallAttrName()),
                       rewriter.getDictionaryAttr(byteir_attrs));

    auto customCallOp = rewriter.create<stablehlo::CustomCallOp>(
        op->getLoc(), resultTypes, operands, ArrayRef<NamedAttribute>{attrs});
    rewriter.replaceOp(op, customCallOp.getResults());
    return success();
  }
};
} // namespace

// aten.nonzero
namespace {
class ConvertAtenNonzeroOp : public OpConversionPattern<AtenNonzeroOp> {
@@ -1392,6 +1434,7 @@ class ConvertTorchToCustomCall
    patterns.add<ConvertFlashAttnFwdOp>(typeConverter, context);
    patterns.add<ConvertFlashAttnBwdOp>(typeConverter, context);
    patterns.add<ConvertFlashAttnKVCacheOp>(typeConverter, context);
    patterns.add<ConvertTritonOp>(typeConverter, context);

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns)))) {
@@ -27,7 +27,7 @@ def inner(f):


@op(
"byteir::flash_attn_fwd(Tensor q, Tensor k, Tensor v, float dropout_p, float softmax_scale, bool casual, bool return_softmax) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)"
"byteir::flash_attn_fwd(Tensor q, Tensor k, Tensor v, float dropout_p, float softmax_scale, bool causal, bool return_softmax) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)"
)
def byteir_flash_attn_fwd(q, k, v, dropout_p, softmax_scale, causal, return_softmax):
    sizes = q.shape
@@ -56,7 +56,7 @@ def byteir_flash_attn_fwd(q, k, v, dropout_p, softmax_scale, causal, return_soft


@op(
"byteir::flash_attn_bwd(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, float dropout_p, float softmax_scale, bool casual, Tensor rng) -> (Tensor, Tensor, Tensor, Tensor, Tensor)"
"byteir::flash_attn_bwd(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, float dropout_p, float softmax_scale, bool causal, Tensor rng) -> (Tensor, Tensor, Tensor, Tensor, Tensor)"
)
def byteir_flash_attn_bwd(
    dout, q, k, v, out, softmax_lse, dropout_p, softmax_scale, causal, rng_state
@@ -83,7 +83,7 @@ def byteir_flash_attn_bwd(


@op(
"byteir::flash_attn_kvcache(Tensor q, Tensor k, Tensor v, Tensor kcache, Tensor vcache, Tensor seqlen_k, float softmax_scale, bool casual) -> (Tensor, Tensor)"
"byteir::flash_attn_kvcache(Tensor q, Tensor k, Tensor v, Tensor kcache, Tensor vcache, Tensor seqlen_k, float softmax_scale, bool causal) -> (Tensor, Tensor)"
)
def byteir_flash_attn_kvcache(q, k, v, kcache, vcache, seqlen_k, softmax_scale, causal):
    sizes = q.shape
@@ -274,3 +274,12 @@ func.func @torch.aten.upsample_nearest2d.vec(%arg0: !torch.vtensor<[1,3,10,20],f
// CHECK: stablehlo.custom_call
// CHECK-SAME: @byteir.resize
// CHECK-NOT: torch.aten.upsample_nearest2d.vec

func.func @torch.triton.custom_op(%arg0: !torch.vtensor<[128,256],f32>, %arg1: !torch.vtensor<[128,256],f32>) -> !torch.vtensor<[128,256],f32> {
%0 = torch.operator "triton.custom_op"(%arg0, %arg1) : (!torch.vtensor<[128,256],f32>, !torch.vtensor<[128,256],f32>) -> (!torch.vtensor<[128,256],f32>)
return %0 : !torch.vtensor<[128,256],f32>
}
// CHECK-LABEL: func.func @torch.triton.custom_op
// CHECK: stablehlo.custom_call
// CHECK-SAME: @triton.custom_op
// CHECK-NOT: torch.operator
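
For reference, the CHECK lines above only pin down the call target. A minimal sketch of what the lowered IR for this test is expected to look like is shown below; it is illustrative and not part of the commit. The byteir_attrs key is an assumption about what getCustomCallAttrName() resolves to (matching the other byteir custom calls in this pass), and getDefaultAttrs may contribute additional attributes that are omitted here.

// Illustrative sketch only (assumed attribute names, default attrs omitted):
%0 = stablehlo.custom_call @triton.custom_op(%arg0, %arg1) {byteir_attrs = {}} : (tensor<128x256xf32>, tensor<128x256xf32>) -> tensor<128x256xf32>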
