From faf28e5df8fa7caeed347fb694b5fcfac301d2af Mon Sep 17 00:00:00 2001 From: Robert Esclapez Date: Tue, 12 Sep 2023 17:40:43 +0200 Subject: [PATCH] Add `getNonSplatRawData` method to `DenseIntOrFPElementsAttr` for fast retrieval of internal data (#78) --- mlir/include/mlir/IR/BuiltinAttributes.td | 18 +++++++ .../Dialect/Tosa/Transforms/TosaFolders.cpp | 6 +-- mlir/unittests/IR/AttributeTest.cpp | 54 +++++++++++++++++++ 3 files changed, 74 insertions(+), 4 deletions(-) diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td index 075eee456a7b58..bcbaf1da940b5d 100644 --- a/mlir/include/mlir/IR/BuiltinAttributes.td +++ b/mlir/include/mlir/IR/BuiltinAttributes.td @@ -348,6 +348,24 @@ def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr< int64_t dataEltSize, bool isInt, bool isSigned); public: + /// Returns the internal buffer of a non-splatted DenseIntOrFPElementAttr. + /// Provided storage_t must be unsigned integer and match the bit-width of + /// the element type of the DenseIntOrFPElementAttr. \returns an + /// ArrayRef over the internal data + template + llvm::ArrayRef getNonSplatRawData() const { + assert(!isSplat() && "DenseElementAttr must not be splatted"); + assert(getElementType().getIntOrFloatBitWidth() == + sizeof(storage_t) * /*bits per byte*/ 8); + + static_assert(std::numeric_limits::is_integer && + !std::numeric_limits::is_signed, + "storage type must be integer and unsigned"); + + return llvm::ArrayRef( + reinterpret_cast(getRawData().data()), + getNumElements()); + } }]; let genAccessors = 0; let genStorageClass = 0; diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp index d32999ec917329..fd8db86e723cd6 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp @@ -225,10 +225,8 @@ template DenseElementsAttr transposeTypeRaw(DenseElementsAttr attr, ShapedType inputType, ShapedType outputType, llvm::ArrayRef permValues) { - - ArrayRef inputValues( - reinterpret_cast(attr.getRawData().data()), - attr.getNumElements()); + ArrayRef inputValues = + cast(attr).getNonSplatRawData(); SmallVector outputValues; outputValues.resize_for_overwrite(inputType.getNumElements()); diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp index f01cc026b72cc2..057fc42cf866c3 100644 --- a/mlir/unittests/IR/AttributeTest.cpp +++ b/mlir/unittests/IR/AttributeTest.cpp @@ -446,3 +446,57 @@ TEST(SubElementTest, Nested) { {strAttr, trueAttr, falseAttr, boolArrayAttr, dictAttr})); } } // namespace + +//===----------------------------------------------------------------------===// +// NonSplattedDenseElementAttr +//===----------------------------------------------------------------------===// + +namespace { +TEST(NonSplattedDenseElementAttrTest, GetNonSplatRawDataF32) { + constexpr std::size_t numberOfElements = 6; + static constexpr std::array rawValues = { + 0., 2., 4., 8., 3.1, 10.4}; + + mlir::MLIRContext context; + mlir::OpBuilder b(&context); + + llvm::SmallVector mlirValues; + llvm::transform(rawValues, std::back_inserter(mlirValues), + [&](float v) { return b.getFloatAttr(b.getF32Type(), v); }); + + llvm::ArrayRef expected( + reinterpret_cast(rawValues.data()), rawValues.size()); + + auto values = mlir::DenseElementsAttr::get( + mlir::RankedTensorType::get({numberOfElements}, b.getF32Type()), + mlirValues); + + EXPECT_EQ(mlir::cast(values) + .getNonSplatRawData(), + expected); +} + +TEST(NonSplattedDenseElementAttrTest, GetNonSplatRawDataI16) { + constexpr std::size_t numberOfElements = 6; + static constexpr std::array rawValues = { + 12, 5723, 23, 2, 634, 321}; + + mlir::MLIRContext context; + mlir::OpBuilder b(&context); + + llvm::SmallVector mlirValues; + llvm::transform(rawValues, std::back_inserter(mlirValues), + [&](int16_t v) { return b.getI16IntegerAttr(v); }); + + llvm::ArrayRef expected( + reinterpret_cast(rawValues.data()), rawValues.size()); + + auto values = mlir::DenseElementsAttr::get( + mlir::RankedTensorType::get({numberOfElements}, b.getI16Type()), + mlirValues); + + EXPECT_EQ(mlir::cast(values) + .getNonSplatRawData(), + expected); +} +} // namespace