From 0b86ada1dfb94818923468aa551695da00e6531d Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Mon, 21 Oct 2024 14:34:50 -0700 Subject: [PATCH] [HLO] Factor out Literal to DenseElements conversion This is a useful utility for PJRT plugins that want to operate on MLIR datatypes, so better to export the method instead of allowing duplicated logic everywhere. PiperOrigin-RevId: 688276008 --- xla/hlo/translate/mhlo_to_hlo/BUILD | 21 +++++ .../translate/mhlo_to_hlo/literal_exporter.cc | 90 +++++++++++++++++++ .../translate/mhlo_to_hlo/literal_exporter.h | 33 +++++++ .../translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc | 56 +----------- 4 files changed, 148 insertions(+), 52 deletions(-) create mode 100644 xla/hlo/translate/mhlo_to_hlo/literal_exporter.cc create mode 100644 xla/hlo/translate/mhlo_to_hlo/literal_exporter.h diff --git a/xla/hlo/translate/mhlo_to_hlo/BUILD b/xla/hlo/translate/mhlo_to_hlo/BUILD index 3a6f48a8dcb73..49c036bf8251b 100644 --- a/xla/hlo/translate/mhlo_to_hlo/BUILD +++ b/xla/hlo/translate/mhlo_to_hlo/BUILD @@ -53,6 +53,26 @@ cc_library( ], ) +cc_library( + name = "literal_exporter", + srcs = ["literal_exporter.cc"], + hdrs = ["literal_exporter.h"], + deps = [ + ":type_to_shape", + "//xla:array", + "//xla:literal", + "//xla:literal_util", + "//xla:shape_util", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@tsl//tsl/platform:errors", + ], +) + cc_library( name = "location_exporter", srcs = ["location_exporter.cc"], @@ -103,6 +123,7 @@ cc_library( deps = [ ":attribute_exporter", ":layout_util", + ":literal_exporter", ":location_exporter", ":module_attributes_exporter", ":operator_writer_inc", diff --git a/xla/hlo/translate/mhlo_to_hlo/literal_exporter.cc b/xla/hlo/translate/mhlo_to_hlo/literal_exporter.cc new file mode 100644 index 0000000000000..821f1487cf88c --- /dev/null +++ b/xla/hlo/translate/mhlo_to_hlo/literal_exporter.cc @@ -0,0 +1,90 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/hlo/translate/mhlo_to_hlo/literal_exporter.h" + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "llvm/ADT/APInt.h" +#include "mlir/IR/BuiltinAttributeInterfaces.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/Support/LLVM.h" +#include "xla/array.h" +#include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h" +#include "xla/layout.h" +#include "xla/literal.h" +#include "xla/literal_util.h" +#include "xla/primitive_util.h" +#include "xla/shape.h" + +namespace mlir { +namespace mhlo { + +template +xla::Array ArrayFromDenseElementsAttr(mlir::DenseElementsAttr dense_attr) { + constexpr xla::PrimitiveType type = + xla::primitive_util::NativeToPrimitiveType(); + xla::Shape shape = xla::TypeToShape(dense_attr.getType()); + xla::Array array(shape.dimensions()); + if constexpr (!xla::primitive_util::IsSubByteNonPredType(type)) { + array.SetValues(dense_attr.getValues()); + } else { + // The only way to get subbyte integers from getValues() is to get them as + // APInts. + auto values = dense_attr.getValues(); + for (int i = 0; i < values.size(); i++) { + if constexpr (xla::primitive_util::IsUnsignedIntegralType(type)) { + array.data()[i] = T{values[i].getZExtValue()}; + } else { + static_assert(xla::primitive_util::IsSignedIntegralType(type)); + array.data()[i] = T{values[i].getSExtValue()}; + } + } + } + return array; +} + +absl::StatusOr CreateLiteralFromAttribute(mlir::ElementsAttr attr, + xla::Layout layout) { + auto dense_attr = mlir::dyn_cast(attr); + if (!dense_attr) + return absl::UnimplementedError("Only dense elements attr are supported"); + + xla::Shape shape = xla::TypeToShape(dense_attr.getType()); + + return xla::primitive_util::PrimitiveTypeSwitch>( + [&](auto primitive_type_constant) -> absl::StatusOr { + if constexpr (xla::primitive_util::IsArrayType( + primitive_type_constant)) { + using cpp_type = + xla::primitive_util::NativeTypeOf; + xla::Array source_data = + ArrayFromDenseElementsAttr(dense_attr); + if (layout.minor_to_major().empty()) { + return xla::LiteralUtil::CreateFromArray(source_data); + } + return xla::LiteralUtil::CreateFromArrayWithLayout(source_data, + layout); + } + return absl::InternalError(absl::StrCat( // NOLINT + "Unsupported type: ", + xla::PrimitiveType_Name(shape.element_type()))); + }, + shape.element_type()); +} + +} // namespace mhlo +} // namespace mlir diff --git a/xla/hlo/translate/mhlo_to_hlo/literal_exporter.h b/xla/hlo/translate/mhlo_to_hlo/literal_exporter.h new file mode 100644 index 0000000000000..f5cb3c74a2819 --- /dev/null +++ b/xla/hlo/translate/mhlo_to_hlo/literal_exporter.h @@ -0,0 +1,33 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSLATE_MHLO_TO_HLO_LITERAL_EXPORTER_H_ +#define XLA_HLO_TRANSLATE_MHLO_TO_HLO_LITERAL_EXPORTER_H_ + +#include "absl/status/statusor.h" +#include "mlir/IR/BuiltinAttributeInterfaces.h" +#include "xla/layout.h" +#include "xla/literal.h" + +namespace mlir { +namespace mhlo { + +absl::StatusOr CreateLiteralFromAttribute(mlir::ElementsAttr attr, + xla::Layout layout); + +} // namespace mhlo +} // namespace mlir + +#endif // XLA_HLO_TRANSLATE_MHLO_TO_HLO_LITERAL_EXPORTER_H_ diff --git a/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc b/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc index 0cd03b39b0481..0807964addc5e 100644 --- a/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc +++ b/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc @@ -79,6 +79,7 @@ limitations under the License. #include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/translate/mhlo_to_hlo/attribute_exporter.h" #include "xla/hlo/translate/mhlo_to_hlo/layout_util.h" +#include "xla/hlo/translate/mhlo_to_hlo/literal_exporter.h" #include "xla/hlo/translate/mhlo_to_hlo/location_exporter.h" #include "xla/hlo/translate/mhlo_to_hlo/module_attributes_exporter.h" #include "xla/hlo/translate/mhlo_to_hlo/stack_frame_index_builder.h" @@ -214,56 +215,6 @@ bool IsBoundedOrStatic(mlir::Type ty) { return true; } -template -xla::Array ArrayFromDenseElementsAttr(mlir::DenseElementsAttr dense_attr) { - constexpr xla::PrimitiveType type = - xla::primitive_util::NativeToPrimitiveType(); - xla::Shape shape = xla::TypeToShape(dense_attr.getType()); - xla::Array array(shape.dimensions()); - if constexpr (!xla::primitive_util::IsSubByteNonPredType(type)) { - array.SetValues(dense_attr.getValues()); - } else { - // The only way to get subbyte integers from getValues() is to get them as - // APInts. - auto values = dense_attr.getValues(); - for (int i = 0; i < values.size(); i++) { - if constexpr (xla::primitive_util::IsUnsignedIntegralType(type)) { - array.data()[i] = T{values[i].getZExtValue()}; - } else { - static_assert(xla::primitive_util::IsSignedIntegralType(type)); - array.data()[i] = T{values[i].getSExtValue()}; - } - } - } - return array; -} - -absl::StatusOr CreateArrayLiteralFromAttr(mlir::ElementsAttr attr, - xla::Layout layout) { - auto dense_attr = mlir::dyn_cast(attr); - if (!dense_attr) - return tsl::errors::Unimplemented("Only dense elements attr are supported"); - - xla::Shape shape = xla::TypeToShape(dense_attr.getType()); - - return xla::primitive_util::PrimitiveTypeSwitch>( - [&](auto primitive_type_constant) -> absl::StatusOr { - if constexpr (xla::primitive_util::IsArrayType( - primitive_type_constant)) { - using cpp_type = - xla::primitive_util::NativeTypeOf; - xla::Array source_data = - ArrayFromDenseElementsAttr(dense_attr); - return xla::LiteralUtil::CreateFromArrayWithLayout(source_data, - layout); - } - return tsl::errors::Internal(absl::StrCat( // NOLINT - "Unsupported type: ", - xla::PrimitiveType_Name(shape.element_type()))); - }, - shape.element_type()); -} - // Convert APInt into an int. // TODO(hpucha): This should be consolidated into a general place. static int ConvertAPInt(llvm::APInt i) { return i.getSExtValue(); } @@ -2268,7 +2219,7 @@ LogicalResult ExportXlaOp(CustomCallOp op, OpLoweringContext ctx) { const xla::Literal* literal_ptr = nullptr; auto literal_attr = op->getAttrOfType(kMhloLiteral); if (literal_attr) { - literal = CreateArrayLiteralFromAttr(literal_attr, {}); + literal = mhlo::CreateLiteralFromAttribute(literal_attr, {}); if (!literal.ok()) return failure(); literal_ptr = &*literal; } @@ -3312,7 +3263,8 @@ LogicalResult ConvertToHloModule::LowerConstant( mlir::FailureOr shape_or = ExtractXlaShape(inst); if (failed(shape_or)) return failure(); - auto literal_or = CreateArrayLiteralFromAttr(const_attr, shape_or->layout()); + auto literal_or = + mhlo::CreateLiteralFromAttribute(const_attr, shape_or->layout()); if (!literal_or.ok()) return inst->emitError(literal_or.status().ToString()); xla::XlaScopedShardingAssignment scoped_sharding(