Skip to content

Commit

Permalink
[HLO] Factor out Literal to DenseElements conversion
Browse files Browse the repository at this point in the history
This is a useful utility for PJRT plugins that want to operate on MLIR datatypes, so better to export the method instead of allowing duplicated logic everywhere.

PiperOrigin-RevId: 688276008
  • Loading branch information
GleasonK authored and Google-ML-Automation committed Oct 24, 2024
1 parent eb3a2b6 commit 0b86ada
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 52 deletions.
21 changes: 21 additions & 0 deletions xla/hlo/translate/mhlo_to_hlo/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,26 @@ cc_library(
],
)

cc_library(
name = "literal_exporter",
srcs = ["literal_exporter.cc"],
hdrs = ["literal_exporter.h"],
deps = [
":type_to_shape",
"//xla:array",
"//xla:literal",
"//xla:literal_util",
"//xla:shape_util",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
"@tsl//tsl/platform:errors",
],
)

cc_library(
name = "location_exporter",
srcs = ["location_exporter.cc"],
Expand Down Expand Up @@ -103,6 +123,7 @@ cc_library(
deps = [
":attribute_exporter",
":layout_util",
":literal_exporter",
":location_exporter",
":module_attributes_exporter",
":operator_writer_inc",
Expand Down
90 changes: 90 additions & 0 deletions xla/hlo/translate/mhlo_to_hlo/literal_exporter.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/hlo/translate/mhlo_to_hlo/literal_exporter.h"

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "llvm/ADT/APInt.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Support/LLVM.h"
#include "xla/array.h"
#include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h"
#include "xla/layout.h"
#include "xla/literal.h"
#include "xla/literal_util.h"
#include "xla/primitive_util.h"
#include "xla/shape.h"

namespace mlir {
namespace mhlo {

template <typename T>
xla::Array<T> ArrayFromDenseElementsAttr(mlir::DenseElementsAttr dense_attr) {
constexpr xla::PrimitiveType type =
xla::primitive_util::NativeToPrimitiveType<T>();
xla::Shape shape = xla::TypeToShape(dense_attr.getType());
xla::Array<T> array(shape.dimensions());
if constexpr (!xla::primitive_util::IsSubByteNonPredType(type)) {
array.SetValues(dense_attr.getValues<T>());
} else {
// The only way to get subbyte integers from getValues() is to get them as
// APInts.
auto values = dense_attr.getValues<llvm::APInt>();
for (int i = 0; i < values.size(); i++) {
if constexpr (xla::primitive_util::IsUnsignedIntegralType(type)) {
array.data()[i] = T{values[i].getZExtValue()};
} else {
static_assert(xla::primitive_util::IsSignedIntegralType(type));
array.data()[i] = T{values[i].getSExtValue()};
}
}
}
return array;
}

absl::StatusOr<xla::Literal> CreateLiteralFromAttribute(mlir::ElementsAttr attr,
xla::Layout layout) {
auto dense_attr = mlir::dyn_cast<mlir::DenseElementsAttr>(attr);
if (!dense_attr)
return absl::UnimplementedError("Only dense elements attr are supported");

xla::Shape shape = xla::TypeToShape(dense_attr.getType());

return xla::primitive_util::PrimitiveTypeSwitch<absl::StatusOr<xla::Literal>>(
[&](auto primitive_type_constant) -> absl::StatusOr<xla::Literal> {
if constexpr (xla::primitive_util::IsArrayType(
primitive_type_constant)) {
using cpp_type =
xla::primitive_util::NativeTypeOf<primitive_type_constant>;
xla::Array<cpp_type> source_data =
ArrayFromDenseElementsAttr<cpp_type>(dense_attr);
if (layout.minor_to_major().empty()) {
return xla::LiteralUtil::CreateFromArray(source_data);
}
return xla::LiteralUtil::CreateFromArrayWithLayout(source_data,
layout);
}
return absl::InternalError(absl::StrCat( // NOLINT
"Unsupported type: ",
xla::PrimitiveType_Name(shape.element_type())));
},
shape.element_type());
}

} // namespace mhlo
} // namespace mlir
33 changes: 33 additions & 0 deletions xla/hlo/translate/mhlo_to_hlo/literal_exporter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_HLO_TRANSLATE_MHLO_TO_HLO_LITERAL_EXPORTER_H_
#define XLA_HLO_TRANSLATE_MHLO_TO_HLO_LITERAL_EXPORTER_H_

#include "absl/status/statusor.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "xla/layout.h"
#include "xla/literal.h"

namespace mlir {
namespace mhlo {

absl::StatusOr<xla::Literal> CreateLiteralFromAttribute(mlir::ElementsAttr attr,
xla::Layout layout);

} // namespace mhlo
} // namespace mlir

#endif // XLA_HLO_TRANSLATE_MHLO_TO_HLO_LITERAL_EXPORTER_H_
56 changes: 4 additions & 52 deletions xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ limitations under the License.
#include "xla/hlo/parser/hlo_parser.h"
#include "xla/hlo/translate/mhlo_to_hlo/attribute_exporter.h"
#include "xla/hlo/translate/mhlo_to_hlo/layout_util.h"
#include "xla/hlo/translate/mhlo_to_hlo/literal_exporter.h"
#include "xla/hlo/translate/mhlo_to_hlo/location_exporter.h"
#include "xla/hlo/translate/mhlo_to_hlo/module_attributes_exporter.h"
#include "xla/hlo/translate/mhlo_to_hlo/stack_frame_index_builder.h"
Expand Down Expand Up @@ -214,56 +215,6 @@ bool IsBoundedOrStatic(mlir::Type ty) {
return true;
}

template <typename T>
xla::Array<T> ArrayFromDenseElementsAttr(mlir::DenseElementsAttr dense_attr) {
constexpr xla::PrimitiveType type =
xla::primitive_util::NativeToPrimitiveType<T>();
xla::Shape shape = xla::TypeToShape(dense_attr.getType());
xla::Array<T> array(shape.dimensions());
if constexpr (!xla::primitive_util::IsSubByteNonPredType(type)) {
array.SetValues(dense_attr.getValues<T>());
} else {
// The only way to get subbyte integers from getValues() is to get them as
// APInts.
auto values = dense_attr.getValues<llvm::APInt>();
for (int i = 0; i < values.size(); i++) {
if constexpr (xla::primitive_util::IsUnsignedIntegralType(type)) {
array.data()[i] = T{values[i].getZExtValue()};
} else {
static_assert(xla::primitive_util::IsSignedIntegralType(type));
array.data()[i] = T{values[i].getSExtValue()};
}
}
}
return array;
}

absl::StatusOr<xla::Literal> CreateArrayLiteralFromAttr(mlir::ElementsAttr attr,
xla::Layout layout) {
auto dense_attr = mlir::dyn_cast<mlir::DenseElementsAttr>(attr);
if (!dense_attr)
return tsl::errors::Unimplemented("Only dense elements attr are supported");

xla::Shape shape = xla::TypeToShape(dense_attr.getType());

return xla::primitive_util::PrimitiveTypeSwitch<absl::StatusOr<xla::Literal>>(
[&](auto primitive_type_constant) -> absl::StatusOr<xla::Literal> {
if constexpr (xla::primitive_util::IsArrayType(
primitive_type_constant)) {
using cpp_type =
xla::primitive_util::NativeTypeOf<primitive_type_constant>;
xla::Array<cpp_type> source_data =
ArrayFromDenseElementsAttr<cpp_type>(dense_attr);
return xla::LiteralUtil::CreateFromArrayWithLayout(source_data,
layout);
}
return tsl::errors::Internal(absl::StrCat( // NOLINT
"Unsupported type: ",
xla::PrimitiveType_Name(shape.element_type())));
},
shape.element_type());
}

// Convert APInt into an int.
// TODO(hpucha): This should be consolidated into a general place.
static int ConvertAPInt(llvm::APInt i) { return i.getSExtValue(); }
Expand Down Expand Up @@ -2268,7 +2219,7 @@ LogicalResult ExportXlaOp(CustomCallOp op, OpLoweringContext ctx) {
const xla::Literal* literal_ptr = nullptr;
auto literal_attr = op->getAttrOfType<DenseElementsAttr>(kMhloLiteral);
if (literal_attr) {
literal = CreateArrayLiteralFromAttr(literal_attr, {});
literal = mhlo::CreateLiteralFromAttribute(literal_attr, {});
if (!literal.ok()) return failure();
literal_ptr = &*literal;
}
Expand Down Expand Up @@ -3312,7 +3263,8 @@ LogicalResult ConvertToHloModule::LowerConstant(
mlir::FailureOr<xla::Shape> shape_or = ExtractXlaShape(inst);
if (failed(shape_or)) return failure();

auto literal_or = CreateArrayLiteralFromAttr(const_attr, shape_or->layout());
auto literal_or =
mhlo::CreateLiteralFromAttribute(const_attr, shape_or->layout());
if (!literal_or.ok()) return inst->emitError(literal_or.status().ToString());

xla::XlaScopedShardingAssignment scoped_sharding(
Expand Down

0 comments on commit 0b86ada

Please sign in to comment.