From 8d7f72345ce58ce959dcc99f2b1a13e5c672b3e9 Mon Sep 17 00:00:00 2001 From: xla authors Date: Thu, 10 Oct 2024 08:50:39 -0700 Subject: [PATCH] [XLA:GPU] Move Triton ops to mlir::triton::xla namespace PiperOrigin-RevId: 684460352 --- xla/service/gpu/fusions/triton/BUILD | 8 ++++---- xla/service/gpu/fusions/triton/xla_triton_ops.cc | 8 ++------ xla/service/gpu/fusions/triton/xla_triton_ops.td | 8 ++++---- xla/service/gpu/tests/sparse_xla_triton_op.mlir | 4 ++-- xla/service/gpu/tests/xla-opt.cc | 2 +- 5 files changed, 13 insertions(+), 17 deletions(-) diff --git a/xla/service/gpu/fusions/triton/BUILD b/xla/service/gpu/fusions/triton/BUILD index 9780723fdee20..2ae038a17d4e9 100644 --- a/xla/service/gpu/fusions/triton/BUILD +++ b/xla/service/gpu/fusions/triton/BUILD @@ -334,28 +334,28 @@ gentbl_cc_library( ( [ "-gen-op-decls", - "-dialect=xla_triton", + "-dialect=triton_xla", ], "xla_triton_ops.h.inc", ), ( [ "-gen-op-defs", - "-dialect=xla_triton", + "-dialect=triton_xla", ], "xla_triton_ops.cc.inc", ), ( [ "-gen-dialect-decls", - "-dialect=xla_triton", + "-dialect=triton_xla", ], "xla_triton_dialect.h.inc", ), ( [ "-gen-dialect-defs", - "-dialect=xla_triton", + "-dialect=triton_xla", ], "xla_triton_dialect.cc.inc", ), diff --git a/xla/service/gpu/fusions/triton/xla_triton_ops.cc b/xla/service/gpu/fusions/triton/xla_triton_ops.cc index 4418e95566e5a..56260cd43d781 100644 --- a/xla/service/gpu/fusions/triton/xla_triton_ops.cc +++ b/xla/service/gpu/fusions/triton/xla_triton_ops.cc @@ -46,11 +46,8 @@ using mlir::SmallVectorImpl; using mlir::TensorOrMemDesc; using mlir::Type; using mlir::ValueRange; -using mlir::triton::DialectInferLayoutInterface; -using mlir::triton::DotOp; -namespace xla { -namespace triton { +namespace mlir::triton::xla { void XlaTritonDialect::initialize() { addOperations< @@ -124,8 +121,7 @@ LogicalResult SparseDotOp::verify() { bEncoding); } -} // namespace triton -} // namespace xla +} // namespace mlir::triton::xla #define GET_OP_CLASSES #include "xla/service/gpu/fusions/triton/xla_triton_ops.cc.inc" diff --git a/xla/service/gpu/fusions/triton/xla_triton_ops.td b/xla/service/gpu/fusions/triton/xla_triton_ops.td index db9f4abf18342..c72ab32e34ea4 100644 --- a/xla/service/gpu/fusions/triton/xla_triton_ops.td +++ b/xla/service/gpu/fusions/triton/xla_triton_ops.td @@ -24,20 +24,20 @@ include "triton/Dialect/Triton/IR/TritonTypeInterfaces.td" include "triton/Dialect/Triton/IR/TritonTypes.td" def XlaTritonDialect : Dialect { - let name = "xla_triton"; + let name = "triton_xla"; let description = [{ This dialect contains ops included in the xla extension point for Triton. }]; - let cppNamespace = "::xla::triton"; + let cppNamespace = "::mlir::triton::xla"; } -class XT_Op traits = []> : +class TTXLA_Op traits = []> : Op { } -def XT_SparseDotOp : XT_Op<"sparse_dot", [ +def TTXLA_SparseDotOp : TTXLA_Op<"sparse_dot", [ Pure, DeclareOpInterfaceMethods, TypesMatchWith<"result's type matches accumulator's type", "d", "c", "$_self">]> { let summary = "sparse dot"; diff --git a/xla/service/gpu/tests/sparse_xla_triton_op.mlir b/xla/service/gpu/tests/sparse_xla_triton_op.mlir index beddba1f2e934..f856f07f0b461 100644 --- a/xla/service/gpu/tests/sparse_xla_triton_op.mlir +++ b/xla/service/gpu/tests/sparse_xla_triton_op.mlir @@ -13,8 +13,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { %B_dot: tensor<64x32xf16, #dot_operand_b>, %meta_reg: tensor<32x4xi16, #dot_meta_enc>) { %acc = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> - // CHECK-LABEL: xla_triton.sparse_dot - %D = xla_triton.sparse_dot %A_dot, %B_dot, %acc, %meta_reg : + // CHECK-LABEL: triton_xla.sparse_dot + %D = triton_xla.sparse_dot %A_dot, %B_dot, %acc, %meta_reg : tensor<32x32xf16, #dot_operand_a> meta tensor<32x4xi16, #dot_meta_enc> * tensor<64x32xf16, #dot_operand_b> -> tensor<32x32xf32, #mma> diff --git a/xla/service/gpu/tests/xla-opt.cc b/xla/service/gpu/tests/xla-opt.cc index cad268b196086..8895104904245 100644 --- a/xla/service/gpu/tests/xla-opt.cc +++ b/xla/service/gpu/tests/xla-opt.cc @@ -24,7 +24,7 @@ int main(int argc, char **argv) { mlir::DialectRegistry registry; mlir::registerAllExtensions(registry); registerTritonDialects(registry); // This registers all passes as well. - registry.insert(); + registry.insert(); xla::gpu::registerTritonFusionTransformsPasses(); xla::gpu::registerGpuFusionTransformsPasses();