Skip to content

Commit

Permalink
Allow to pass attributes when mapping from Mhlo to Scalar.
Browse files Browse the repository at this point in the history
This will be used in a future change, for now it is a non-functional change.

PiperOrigin-RevId: 689349418
  • Loading branch information
akuegel authored and Google-ML-Automation committed Oct 24, 2024
1 parent 9eb9710 commit 8a6a64a
Show file tree
Hide file tree
Showing 7 changed files with 178 additions and 165 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,8 @@ struct ScalarHloToArithmeticPattern : public OpConversionPattern<OpTy> {
rewriter.create<tensor::ExtractOp>(loc, operand, ValueRange()));
}
Value scalarResult = mhlo::MhloOpToStdScalarOp::mapOp(
op, resultTy->getElementType(), operands, &rewriter);
op, resultTy->getElementType(), operands, /*attributes=*/std::nullopt,
&rewriter);
if (!scalarResult) return failure();
rewriter.replaceOpWithNewOp<tensor::FromElementsOp>(op, *resultTy,
scalarResult);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.
#include <algorithm>
#include <memory>
#include <numeric>
#include <optional>
#include <string>
#include <utility>

Expand Down Expand Up @@ -103,7 +104,8 @@ class MhloElementwiseConverter : public OpRewritePattern<OpTy> {
}

Value scalarOp = mhlo::MhloOpToStdScalarOp::mapOp(
op, resultTy.getElementType(), extracts, &rewriter);
op, resultTy.getElementType(), extracts, /*attributes=*/std::nullopt,
&rewriter);
operands.push_back(scalarOp);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,7 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<MhloOp> {
auto rhs = rewriter.create<memref::LoadOp>(loc, mhloOp.rhs());
Value opResult = mhlo::MhloOpToStdScalarOp::mapOp(
mhloOp, argType.getElementType(), llvm::ArrayRef<Value>{lhs, rhs},
&rewriter);
/*attributes=*/std::nullopt, &rewriter);
rewriter.create<memref::StoreOp>(loc, opResult, mhloOp.out());
rewriter.eraseOp(mhloOp);
return success();
Expand Down Expand Up @@ -1512,7 +1512,7 @@ class IotaConverter : public OpConversionPattern<OpTy> {
indexOp);
castOp = mhlo::MhloOpToStdScalarOp::mapConvertOpToStdScalarOp(
nestedLoc, targetElementType, resultElementType, castOp.getType(),
{castOp}, &nestedBuilder);
{castOp}, /*attributes=*/std::nullopt, &nestedBuilder);
nestedBuilder.create<linalg::YieldOp>(nestedLoc, castOp);
},
linalg::getPrunedAttributeList(iotaOp));
Expand Down Expand Up @@ -1548,7 +1548,8 @@ class IotaToMapConverter : public OpConversionPattern<OpTy> {
nestedLoc, nestedBuilder.getI64Type(), index);
Value result = mhlo::MhloOpToStdScalarOp::mapConvertOpToStdScalarOp(
nestedLoc, targetElementType, resultTy.getElementType(),
index.getType(), {ValueRange{index}}, &nestedBuilder);
index.getType(), {ValueRange{index}}, /*attributes=*/std::nullopt,
&nestedBuilder);
nestedBuilder.create<linalg::YieldOp>(nestedLoc, ValueRange{result});
},
linalg::getPrunedAttributeList(iotaOp));
Expand Down Expand Up @@ -4369,7 +4370,8 @@ class PointwiseToLinalgMapConverter : public OpConversionPattern<OpTy> {
[&](OpBuilder& b, Location loc, ValueRange args) {
Value innerResult = mhlo::MhloOpToStdScalarOp::mapOp(
op, getElementTypeOrSelf(emptyTensor),
interleaveScalarAndBlockArgs(scalarInputs, args), &b);
interleaveScalarAndBlockArgs(scalarInputs, args),
/*attributes=*/std::nullopt, &b);
b.create<linalg::YieldOp>(loc, innerResult);
},
linalg::getPrunedAttributeList(op));
Expand Down
310 changes: 157 additions & 153 deletions xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion xla/mlir_hlo/mhlo/utils/legalize_to_linalg_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,8 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
auto argvec = llvm::to_vector<2>(args.take_front(inputs.size()));
auto semiring = preSparsify(op, argvec, innerResultTy, &rewriter);
Value innerResult = mhlo::MhloOpToStdScalarOp::mapOp(
op, innerResultTy, argvec, &rewriter);
op, innerResultTy, argvec, /*attributes=*/std::nullopt,
&rewriter);
if (innerResult == nullptr) {
failed = true;
} else {
Expand Down
9 changes: 5 additions & 4 deletions xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,7 @@ SmallVector<Value, 1> MapHloOp(mlir::Type result_type,
Value result = mhlo::MhloOpToStdScalarOp::mapOpOfType<MhloOp>(
b.getLoc(), result_type, arg_types,
typename MhloOp::Adaptor(args, std::forward<ExtraArgs>(extra_args)...),
&b);
/*attributes=*/std::nullopt, &b);
if (result.getType().isInteger(1)) {
result = b.create<mlir::arith::ExtUIOp>(b.getI8Type(), result);
}
Expand Down Expand Up @@ -854,7 +854,7 @@ absl::StatusOr<SmallVector<Value, 1>> EmitConvert(
}
auto out = mhlo::MhloOpToStdScalarOp::mapConvertOpToStdScalarOp(
builder.getLoc(), result_type_with_sign, result_element_type, arg_types,
operands, &builder);
operands, /*attributes=*/std::nullopt, &builder);
if (auto int_ty = mlir::dyn_cast<IntegerType>(out.getType())) {
auto in = operands[0];
if (auto float_ty = mlir::dyn_cast<FloatType>(in.getType())) {
Expand Down Expand Up @@ -919,7 +919,7 @@ absl::StatusOr<SmallVector<Value, 1>> EmitIota(const HloInstruction* instr,
index = builder.create<arith::IndexCastUIOp>(index_type, index);
return {{mhlo::MhloOpToStdScalarOp::mapConvertOpToStdScalarOp(
builder.getLoc(), result_type_with_sign, result_element_type,
{index_type}, {index}, &builder)}};
{index_type}, {index}, /*attributes=*/std::nullopt, &builder)}};
}

absl::StatusOr<SmallVector<Value, 1>> EmitCompare(
Expand All @@ -934,7 +934,8 @@ absl::StatusOr<SmallVector<Value, 1>> EmitCompare(
auto result_types = llvm::to_vector(mlir::TypeRange{builder.getI1Type()});
auto i1 = mhlo::MhloOpToStdScalarOp::mapOpOfType<mhlo::CompareOp>(
builder.getLoc(), result_types, arg_types,
mhlo::CompareOp::Adaptor(operands, nullptr, properties), &builder);
mhlo::CompareOp::Adaptor(operands, nullptr, properties),
/*attributes=*/std::nullopt, &builder);
return {{builder.create<mlir::arith::ExtUIOp>(builder.getI8Type(), i1)
.getResult()}};
}
Expand Down
4 changes: 3 additions & 1 deletion xla/service/gpu/fusions/transforms/expand_float_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
#include <cassert>
#include <cstdint>
#include <memory>
#include <optional>
#include <utility>

#include "absl/log/check.h"
Expand Down Expand Up @@ -218,7 +219,8 @@ Value EmitReducePrecision(Value value, int exponent_bits, int mantissa_bits,
return mlir::mhlo::MhloOpToStdScalarOp::mapOpOfType<
mlir::mhlo::ReducePrecisionOp>(
b.getLoc(), value.getType(), {value.getType()},
mlir::mhlo::ReducePrecisionOp::Adaptor(value, nullptr, properties), &b);
mlir::mhlo::ReducePrecisionOp::Adaptor(value, nullptr, properties),
/*attributes=*/std::nullopt, &b);
}

Value EmitF16ToF8e5m2(Value in, mlir::ImplicitLocOpBuilder& b) {
Expand Down

0 comments on commit 8a6a64a

Please sign in to comment.