Skip to content

Commit

Permalink
Add getNonSplatRawData method to DenseIntOrFPElementsAttr for fas…
Browse files Browse the repository at this point in the history
…t retrieval of internal data (#78)
  • Loading branch information
roberteg16 authored Sep 12, 2023
1 parent 22a1f58 commit faf28e5
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 4 deletions.
18 changes: 18 additions & 0 deletions mlir/include/mlir/IR/BuiltinAttributes.td
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,24 @@ def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr<
int64_t dataEltSize, bool isInt,
bool isSigned);
public:
/// Returns the internal buffer of a non-splatted DenseIntOrFPElementAttr.
/// Provided storage_t must be unsigned integer and match the bit-width of
/// the element type of the DenseIntOrFPElementAttr. \returns an
/// ArrayRef<storage_t> over the internal data
template <typename storage_t>
llvm::ArrayRef<storage_t> getNonSplatRawData() const {
assert(!isSplat() && "DenseElementAttr must not be splatted");
assert(getElementType().getIntOrFloatBitWidth() ==
sizeof(storage_t) * /*bits per byte*/ 8);

static_assert(std::numeric_limits<storage_t>::is_integer &&
!std::numeric_limits<storage_t>::is_signed,
"storage type must be integer and unsigned");

return llvm::ArrayRef<storage_t>(
reinterpret_cast<const storage_t *>(getRawData().data()),
getNumElements());
}
}];
let genAccessors = 0;
let genStorageClass = 0;
Expand Down
6 changes: 2 additions & 4 deletions mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -225,10 +225,8 @@ template <typename BaseType>
DenseElementsAttr transposeTypeRaw(DenseElementsAttr attr, ShapedType inputType,
ShapedType outputType,
llvm::ArrayRef<int64_t> permValues) {

ArrayRef inputValues(
reinterpret_cast<const BaseType *>(attr.getRawData().data()),
attr.getNumElements());
ArrayRef<BaseType> inputValues =
cast<DenseIntOrFPElementsAttr>(attr).getNonSplatRawData<BaseType>();

SmallVector<BaseType> outputValues;
outputValues.resize_for_overwrite(inputType.getNumElements());
Expand Down
54 changes: 54 additions & 0 deletions mlir/unittests/IR/AttributeTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -446,3 +446,57 @@ TEST(SubElementTest, Nested) {
{strAttr, trueAttr, falseAttr, boolArrayAttr, dictAttr}));
}
} // namespace

//===----------------------------------------------------------------------===//
// NonSplattedDenseElementAttr
//===----------------------------------------------------------------------===//

namespace {
TEST(NonSplattedDenseElementAttrTest, GetNonSplatRawDataF32) {
constexpr std::size_t numberOfElements = 6;
static constexpr std::array<float, numberOfElements> rawValues = {
0., 2., 4., 8., 3.1, 10.4};

mlir::MLIRContext context;
mlir::OpBuilder b(&context);

llvm::SmallVector<mlir::Attribute> mlirValues;
llvm::transform(rawValues, std::back_inserter(mlirValues),
[&](float v) { return b.getFloatAttr(b.getF32Type(), v); });

llvm::ArrayRef<uint32_t> expected(
reinterpret_cast<const uint32_t *>(rawValues.data()), rawValues.size());

auto values = mlir::DenseElementsAttr::get(
mlir::RankedTensorType::get({numberOfElements}, b.getF32Type()),
mlirValues);

EXPECT_EQ(mlir::cast<mlir::DenseIntOrFPElementsAttr>(values)
.getNonSplatRawData<uint32_t>(),
expected);
}

TEST(NonSplattedDenseElementAttrTest, GetNonSplatRawDataI16) {
constexpr std::size_t numberOfElements = 6;
static constexpr std::array<int16_t, numberOfElements> rawValues = {
12, 5723, 23, 2, 634, 321};

mlir::MLIRContext context;
mlir::OpBuilder b(&context);

llvm::SmallVector<mlir::Attribute> mlirValues;
llvm::transform(rawValues, std::back_inserter(mlirValues),
[&](int16_t v) { return b.getI16IntegerAttr(v); });

llvm::ArrayRef<uint16_t> expected(
reinterpret_cast<const uint16_t *>(rawValues.data()), rawValues.size());

auto values = mlir::DenseElementsAttr::get(
mlir::RankedTensorType::get({numberOfElements}, b.getI16Type()),
mlirValues);

EXPECT_EQ(mlir::cast<mlir::DenseIntOrFPElementsAttr>(values)
.getNonSplatRawData<uint16_t>(),
expected);
}
} // namespace

0 comments on commit faf28e5

Please sign in to comment.