[XLA:GPU] Remove untested int4 path in the Triton fusion emitter.
`int4` support was added for only very restricted cases in the Triton matmul
emitter, and was not yet intended to be supported in
`triton_fusion_emitter.cc`.

Also explicitly remove the dependency of
`triton_fusion_emitter_legacy_matmul.cc` on `emitter_helpers.h`. The helpers
were designed to be used with the new approach to Triton fusions, and the code
shouldn't be shared.

PiperOrigin-RevId: 689704984
bchetioui authored and Google-ML-Automation committed Oct 25, 2024
1 parent f548642 commit 4663f04
Showing 6 changed files with 68 additions and 60 deletions.
3 changes: 2 additions & 1 deletion xla/service/gpu/fusions/triton/BUILD
@@ -166,11 +166,12 @@ cc_library(
),
hdrs = ["triton_fusion_emitter_legacy_matmul.h"],
deps = [
":emitter_helpers",
"//xla:comparison_util",
"//xla:literal",
"//xla:shape_util",
"//xla:status_macros",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/hlo/utils:hlo_query",
"//xla/mlir_hlo",
30 changes: 0 additions & 30 deletions xla/service/gpu/fusions/triton/emitter_helpers.cc
@@ -102,10 +102,6 @@ absl::StatusOr<Type> TritonType(mlir::OpBuilder b, PrimitiveType t) {
return b.getI1Type();
case S8:
return b.getI8Type();
case S4: // The unpacking to i8 is supported by the emitter.
// We pass the s4 tensor as i8 tensor with the minor dimension having 2x
// less elements and unpack in the inner loop of the triton kernel.
return b.getI8Type();
case F8E5M2:
return b.getFloat8E5M2Type();
case F8E4M3FN:
@@ -442,30 +438,4 @@ absl::StatusOr<ScalarOrTensor> EmitConstant(ImplicitLocOpBuilder& b,
return CreateConst(b, ty, ScalarConstantValue<double>(constant, F64), shape);
}

// Emit sequence of operations for unpacking 2xi4 -> i8.
absl::StatusOr<Value> EmitUnpackInt4(ImplicitLocOpBuilder& b,
const HloInstruction* hlo,
int64_t unpack_dim_idx, Value value) {
VLOG(6) << "EmitUnpackInt4: " << hlo->ToString();
auto input_type = mlir::cast<mlir::RankedTensorType>(value.getType());
if (input_type.getShape().size() != 2) {
return absl::InvalidArgumentError(
absl::StrCat("UnpackInt4 works only for 2d inputs: ", hlo->ToString()));
}
// We use shifts instead of the mask because we need to keep the sign bit.
Value shift4 =
Splat(b, CreateConst(b, b.getI8Type(), 4), input_type.getShape())
.UnwrapUnsafe();
Value lo = b.create<ma::ShRSIOp>(b.create<ma::ShLIOp>(value, shift4), shift4);
Value hi = b.create<ma::ShRSIOp>(value, shift4);
Value result = b.create<mt::JoinOp>(hi, lo);
if (unpack_dim_idx == 0) {
result = b.create<mt::TransOp>(result, b.getDenseI32ArrayAttr({0, 2, 1}));
}
SmallVector<int64_t> result_shape(input_type.getShape());
result_shape[unpack_dim_idx] *= 2;
auto type = mlir::RankedTensorType::get(result_shape, b.getI8Type());
return b.create<mt::ReshapeOp>(type, result, /*allow_reorder=*/false);
}

} // namespace xla::gpu::triton
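
For reference, the removed EmitUnpackInt4 helper relied on a sign-preserving shift trick: shifting a byte left by four and then arithmetic-shifting right by four sign-extends the low nibble, while an arithmetic right shift by four alone yields the high nibble. Below is a minimal standalone C++ sketch of the same bit manipulation on a single byte (illustration only, not part of the XLA sources; the function and variable names are made up here).

#include <array>
#include <cstdint>
#include <cstdio>

// Unpacks one byte holding two signed 4-bit values into two sign-extended
// int8_t values, using shifts rather than a 0x0F mask so the sign bit of
// each nibble is preserved.
std::array<int8_t, 2> UnpackTwoInt4(int8_t packed) {
  int8_t lo = static_cast<int8_t>(static_cast<uint8_t>(packed) << 4) >> 4;
  int8_t hi = static_cast<int8_t>(packed >> 4);
  return {hi, lo};
}

int main() {
  // 0xA3 packs the nibbles 0xA (-6 as s4) and 0x3 (+3 as s4).
  auto [hi, lo] = UnpackTwoInt4(static_cast<int8_t>(0xA3));
  std::printf("hi=%d lo=%d\n", hi, lo);  // Prints hi=-6 lo=3.
  return 0;
}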
5 changes: 0 additions & 5 deletions xla/service/gpu/fusions/triton/emitter_helpers.h
@@ -193,11 +193,6 @@ absl::StatusOr<mlir::Value> EmitElementwise(
const se::DeviceDescription& device_info, const HloInstruction& hlo,
mlir::ValueRange inputs);

// Emit sequence of operations for unpacking 2xi4 -> i8.
absl::StatusOr<mlir::Value> EmitUnpackInt4(mlir::ImplicitLocOpBuilder& b,
const HloInstruction* hlo,
int64_t unpack_dim_idx,
mlir::Value value);
} // namespace xla::gpu::triton

#endif // XLA_SERVICE_GPU_FUSIONS_TRITON_EMITTER_HELPERS_H_
19 changes: 4 additions & 15 deletions xla/service/gpu/fusions/triton/triton_fusion_emitter.cc
@@ -98,7 +98,6 @@ limitations under the License.
#include "xla/service/gpu/fusions/triton/emitter_helpers.h"
#include "xla/service/gpu/fusions/triton/passes.h"
#include "xla/service/gpu/fusions/triton/triton_fusion_emitter_legacy_matmul.h"
#include "xla/service/gpu/fusions/triton/triton_support.h"
#include "xla/service/gpu/ir_emission_utils.h"
#include "xla/service/gpu/model/indexing_analysis.h"
#include "xla/service/gpu/model/indexing_map.h"
@@ -147,7 +146,6 @@ using ::xla::gpu::triton::Cast;
using ::xla::gpu::triton::CreateConst;
using ::xla::gpu::triton::EmitConstant;
using ::xla::gpu::triton::EmitElementwise;
using ::xla::gpu::triton::EmitUnpackInt4;
using ::xla::gpu::triton::GetPaddedTileSizes;
using ::xla::gpu::triton::ScalarOrTensor;
using ::xla::gpu::triton::StorageType;
@@ -738,19 +736,8 @@ absl::StatusOr<ScalarOrTensor> EmitScope(
absl::flat_hash_map<const HloInstruction*, ScalarOrTensor>& values) {
for (const HloInstruction* hlo : instructions) {
ScalarOrTensor result;
if (hlo->opcode() == HloOpcode::kConvert &&
hlo->operand(0)->shape().element_type() == S4) {
TF_ASSIGN_OR_RETURN(
auto unpacked,
EmitUnpackInt4(b, hlo, /*unpack_dim_idx=*/0,
values[hlo->operand(0)].UnwrapUnsafe()));
std::vector<Value> operands({unpacked});
TF_ASSIGN_OR_RETURN(
Value elementwise_result,
EmitElementwise(b, libdevice_path, device_info, *hlo, operands));
result = ScalarOrTensor(elementwise_result);
} else if (hlo->opcode() == HloOpcode::kConcatenate ||
hlo->opcode() == HloOpcode::kDynamicSlice) {
if (hlo->opcode() == HloOpcode::kConcatenate ||
hlo->opcode() == HloOpcode::kDynamicSlice) {
// Parameter loads and their concatenations are handled outside EmitScope.
TF_RET_CHECK(values.contains(hlo)) << hlo->ToString();
continue;
@@ -1090,6 +1077,8 @@ absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> CreateTritonModule(
Type ir_type;
if (type == U16) {
ir_type = b.getI16Type();
} else if (type == S4) {
ir_type = b.getI8Type();
} else {
TF_ASSIGN_OR_RETURN(ir_type, TritonType(b, type));
}
67 changes: 61 additions & 6 deletions xla/service/gpu/fusions/triton/triton_fusion_emitter_legacy_matmul.cc
@@ -61,12 +61,12 @@ limitations under the License.
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/utils/hlo_query.h"
#include "xla/literal.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
#include "xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h"
#include "xla/mlir_hlo/mhlo/transforms/transformation_helpers.h"
#include "xla/primitive_util.h"
#include "xla/service/algorithm_util.h"
#include "xla/service/gpu/fusions/triton/emitter_helpers.h"
#include "xla/service/gpu/hlo_traversal.h"
#include "xla/service/gpu/ir_emission_utils.h"
#include "xla/service/gpu/launch_dimensions.h"
@@ -82,6 +82,7 @@ limitations under the License.
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/launch_dim.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/status.h"
#include "tsl/platform/statusor.h"
@@ -104,11 +105,50 @@ using ::mlir::Type;
using ::mlir::Value;
using ::mlir::ValueRange;

using ::xla::gpu::triton::StorageType;
using ::xla::gpu::triton::TritonType;

namespace {

absl::StatusOr<Type> TritonType(mlir::OpBuilder b, PrimitiveType t) {
switch (t) {
case F64:
return b.getF64Type();
case F32:
return b.getF32Type();
case F16:
return b.getF16Type();
case BF16:
return b.getBF16Type();
case S64:
return b.getI64Type();
case S32:
return b.getI32Type();
case S16:
return b.getI16Type();
case PRED:
return b.getI1Type();
case S8:
return b.getI8Type();
case S4: // The unpacking to i8 is supported by the emitter.
// We pass the s4 tensor as i8 tensor with the minor dimension having 2x
// less elements and unpack in the inner loop of the triton kernel.
return b.getI8Type();
case F8E5M2:
return b.getFloat8E5M2Type();
case F8E4M3FN:
return b.getFloat8E4M3FNType();
default:
return absl::UnimplementedError(
absl::StrCat("This type is not supported yet: ",
primitive_util::LowercasePrimitiveTypeName(t)));
}
}

Type StorageType(mlir::OpBuilder b, Type t) {
if (t.isInteger(1)) {
return b.getI8Type();
}
return t;
}

// Create a scalar constant.
template <typename T>
mlir::arith::ConstantOp CreateConst(mlir::ImplicitLocOpBuilder b,
@@ -454,8 +494,23 @@ absl::StatusOr<Value> EmitElementwise(ImplicitLocOpBuilder& b,

absl::StatusOr<Value> EmitConstant(ImplicitLocOpBuilder& b,
const HloInstruction& constant) {
TF_ASSIGN_OR_RETURN(auto result, triton::EmitConstant(b, constant));
return result.UnwrapScalar();
CHECK_EQ(constant.opcode(), HloOpcode::kConstant);
CHECK(ShapeUtil::IsEffectiveScalar(constant.shape()));

TF_ASSIGN_OR_RETURN(Type ty, TritonType(b, constant.shape().element_type()));

if (constant.shape().element_type() == U64) {
TF_ASSIGN_OR_RETURN(Literal converted, constant.literal().Convert(U64));
return CreateConst(b, ty, converted.GetFirstElement<uint64_t>());
}

if (constant.shape().IsInteger()) {
TF_ASSIGN_OR_RETURN(Literal converted, constant.literal().Convert(S64));
return CreateConst(b, ty, converted.GetFirstElement<int64_t>());
}

TF_ASSIGN_OR_RETURN(Literal converted, constant.literal().Convert(F64));
return CreateConst(b, ty, converted.GetFirstElement<double>());
}

// Emit sequence of operations for unpacking 2xi4 -> i8.
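
The S4 case kept in this file's local TritonType helper refers to the storage convention described in its comment: a logical s4 tensor travels as an i8 buffer whose minor dimension is halved, two nibbles per byte, and is unpacked again in the kernel's inner loop. A rough sketch of that packing follows (illustration only; the high-then-low nibble order and the PackS4RowMajor name are assumptions, not taken from the XLA sources).

#include <cstdint>
#include <vector>

// Packs a row-major logical s4 tensor of shape [rows, cols] into an i8
// buffer of shape [rows, cols / 2], halving the minor dimension.
// Assumes cols is even.
std::vector<int8_t> PackS4RowMajor(const std::vector<int8_t>& logical,
                                   int64_t rows, int64_t cols) {
  std::vector<int8_t> packed(static_cast<size_t>(rows * cols / 2));
  for (int64_t r = 0; r < rows; ++r) {
    for (int64_t c = 0; c < cols; c += 2) {
      // Keep only the low nibble of each logical value, then pack two
      // neighbours along the minor dimension into one byte.
      uint8_t hi = static_cast<uint8_t>(logical[r * cols + c]) & 0x0F;
      uint8_t lo = static_cast<uint8_t>(logical[r * cols + c + 1]) & 0x0F;
      packed[r * (cols / 2) + c / 2] = static_cast<int8_t>((hi << 4) | lo);
    }
  }
  return packed;
}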
4 changes: 1 addition & 3 deletions xla/service/gpu/fusions/triton/triton_support.cc
@@ -44,7 +44,6 @@ bool IsTritonSupportedDataType(PrimitiveType type,
const se::GpuComputeCapability& gpu_version) {
switch (type) {
case PRED:
case S4:
case S8:
case S16:
case S32:
@@ -143,8 +142,7 @@ CodegenDecision IsTritonSupportedConversion(
}

if (IsTritonSupportedDataType(input, gpu_version) &&
(IsTritonSupportedDataType(output, gpu_version) ||
output == PrimitiveType::S4)) {
IsTritonSupportedDataType(output, gpu_version)) {
return CodegenDecision::Allow();
}
