Skip to content

Commit

Permalink
[XLA:GPU] Move Triton ops to mlir::triton::xla namespace
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 684460352
  • Loading branch information
Google-ML-Automation committed Oct 10, 2024
1 parent 26c6f1f commit 8d7f723
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 17 deletions.
8 changes: 4 additions & 4 deletions xla/service/gpu/fusions/triton/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -334,28 +334,28 @@ gentbl_cc_library(
(
[
"-gen-op-decls",
"-dialect=xla_triton",
"-dialect=triton_xla",
],
"xla_triton_ops.h.inc",
),
(
[
"-gen-op-defs",
"-dialect=xla_triton",
"-dialect=triton_xla",
],
"xla_triton_ops.cc.inc",
),
(
[
"-gen-dialect-decls",
"-dialect=xla_triton",
"-dialect=triton_xla",
],
"xla_triton_dialect.h.inc",
),
(
[
"-gen-dialect-defs",
"-dialect=xla_triton",
"-dialect=triton_xla",
],
"xla_triton_dialect.cc.inc",
),
Expand Down
8 changes: 2 additions & 6 deletions xla/service/gpu/fusions/triton/xla_triton_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,8 @@ using mlir::SmallVectorImpl;
using mlir::TensorOrMemDesc;
using mlir::Type;
using mlir::ValueRange;
using mlir::triton::DialectInferLayoutInterface;
using mlir::triton::DotOp;

namespace xla {
namespace triton {
namespace mlir::triton::xla {

void XlaTritonDialect::initialize() {
addOperations<
Expand Down Expand Up @@ -124,8 +121,7 @@ LogicalResult SparseDotOp::verify() {
bEncoding);
}

} // namespace triton
} // namespace xla
} // namespace mlir::triton::xla

#define GET_OP_CLASSES
#include "xla/service/gpu/fusions/triton/xla_triton_ops.cc.inc"
8 changes: 4 additions & 4 deletions xla/service/gpu/fusions/triton/xla_triton_ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,20 @@ include "triton/Dialect/Triton/IR/TritonTypeInterfaces.td"
include "triton/Dialect/Triton/IR/TritonTypes.td"

def XlaTritonDialect : Dialect {
let name = "xla_triton";
let name = "triton_xla";

let description = [{
This dialect contains ops included in the xla extension point for Triton.
}];

let cppNamespace = "::xla::triton";
let cppNamespace = "::mlir::triton::xla";
}

class XT_Op<string mnemonic, list<Trait> traits = []> :
class TTXLA_Op<string mnemonic, list<Trait> traits = []> :
Op<XlaTritonDialect, mnemonic, traits> {
}

def XT_SparseDotOp : XT_Op<"sparse_dot", [
def TTXLA_SparseDotOp : TTXLA_Op<"sparse_dot", [
Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>,
TypesMatchWith<"result's type matches accumulator's type", "d", "c", "$_self">]> {
let summary = "sparse dot";
Expand Down
4 changes: 2 additions & 2 deletions xla/service/gpu/tests/sparse_xla_triton_op.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
%B_dot: tensor<64x32xf16, #dot_operand_b>,
%meta_reg: tensor<32x4xi16, #dot_meta_enc>) {
%acc = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
// CHECK-LABEL: xla_triton.sparse_dot
%D = xla_triton.sparse_dot %A_dot, %B_dot, %acc, %meta_reg :
// CHECK-LABEL: triton_xla.sparse_dot
%D = triton_xla.sparse_dot %A_dot, %B_dot, %acc, %meta_reg :
tensor<32x32xf16, #dot_operand_a> meta tensor<32x4xi16,
#dot_meta_enc> * tensor<64x32xf16, #dot_operand_b>
-> tensor<32x32xf32, #mma>
Expand Down
2 changes: 1 addition & 1 deletion xla/service/gpu/tests/xla-opt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ int main(int argc, char **argv) {
mlir::DialectRegistry registry;
mlir::registerAllExtensions(registry);
registerTritonDialects(registry); // This registers all passes as well.
registry.insert<xla::triton::XlaTritonDialect>();
registry.insert<mlir::triton::xla::XlaTritonDialect>();
xla::gpu::registerTritonFusionTransformsPasses();
xla::gpu::registerGpuFusionTransformsPasses();

Expand Down

0 comments on commit 8d7f723

Please sign in to comment.