-
Notifications
You must be signed in to change notification settings - Fork 608
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add simple pass to turn dense attributes into dense_resource attribut…
…es. (#14574)
- Loading branch information
1 parent
c35c88e
commit 1791958
Showing
8 changed files
with
323 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
206 changes: 206 additions & 0 deletions
206
compiler/src/iree/compiler/Dialect/Util/Transforms/ImportResources.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,206 @@ | ||
// Copyright 2023 The IREE Authors | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
#include <utility> | ||
|
||
#include "iree/compiler/Dialect/Util/Transforms/PassDetail.h" | ||
#include "iree/compiler/Dialect/Util/Transforms/Passes.h" | ||
#include "llvm/ADT/DenseMap.h" | ||
#include "llvm/Support/Debug.h" | ||
#include "mlir/IR/BuiltinAttributes.h" | ||
#include "mlir/IR/BuiltinDialect.h" | ||
#include "mlir/Pass/Pass.h" | ||
|
||
#define DEBUG_TYPE "iree-util-import-resources" | ||
|
||
namespace mlir::iree_compiler::IREE::Util { | ||
|
||
namespace { | ||
|
||
// TODO: Just use the DenseResourceElementsAttr::get() | ||
// builder once https://reviews.llvm.org/D157064 lands. | ||
class DenseBlobResourceElementsAttr : public DenseResourceElementsAttr { | ||
public: | ||
using DenseResourceElementsAttr::get; | ||
}; | ||
|
||
template <typename ElementType, unsigned numBits = sizeof(ElementType) * 8> | ||
static void copyIntAttrIntoBlob(AsmResourceBlob &blob, | ||
DenseIntElementsAttr attr) { | ||
ArrayRef<ElementType> data = blob.getDataAs<ElementType>(); | ||
MutableArrayRef<ElementType> rwData = MutableArrayRef<ElementType>( | ||
const_cast<ElementType *>(data.data()), data.size()); | ||
ArrayRef<char> rawSrcData = attr.getRawData(); | ||
if (rawSrcData.size() == blob.getData().size()) { | ||
// Memcpy. | ||
std::memcpy(rwData.data(), rawSrcData.data(), rawSrcData.size()); | ||
} else { | ||
// Slow. | ||
size_t index = 0; | ||
for (APInt value : attr.getValues<APInt>()) { | ||
rwData[index++] = value.extractBitsAsZExtValue(numBits, 0); | ||
} | ||
} | ||
} | ||
|
||
template <typename ElementType, unsigned numBits = sizeof(ElementType) * 8> | ||
static void copyFPAttrIntoBlob(AsmResourceBlob &blob, | ||
DenseFPElementsAttr attr) { | ||
ArrayRef<ElementType> data = blob.getDataAs<ElementType>(); | ||
MutableArrayRef<ElementType> rwData = MutableArrayRef<ElementType>( | ||
const_cast<ElementType *>(data.data()), data.size()); | ||
ArrayRef<char> rawSrcData = attr.getRawData(); | ||
if (rawSrcData.size() == blob.getData().size()) { | ||
// Memcpy. | ||
std::memcpy(rwData.data(), rawSrcData.data(), rawSrcData.size()); | ||
} else { | ||
// Slow. | ||
size_t index = 0; | ||
for (APFloat value : attr.getValues<APFloat>()) { | ||
rwData[index++] = | ||
value.bitcastToAPInt().extractBitsAsZExtValue(numBits, 0); | ||
} | ||
} | ||
} | ||
|
||
class ImportResourcesPass : public ImportResourcesBase<ImportResourcesPass> { | ||
public: | ||
void getDependentDialects(DialectRegistry ®istry) const override { | ||
registry.insert<BuiltinDialect>(); | ||
} | ||
|
||
void runOnOperation() override { | ||
llvm::DenseMap<Attribute, Attribute> replacements; | ||
|
||
getOperation()->walk([&](Operation *op) { | ||
bool updated = false; | ||
SmallVector<NamedAttribute> attrs(op->getAttrs()); | ||
for (auto &attr : attrs) { | ||
if (auto elements = llvm::dyn_cast<ElementsAttr>(attr.getValue())) { | ||
// Already seen? | ||
auto it = replacements.find(elements); | ||
if (it != replacements.end()) { | ||
LLVM_DEBUG(llvm::dbgs() | ||
<< ":: Replacing already encountered attr of " | ||
<< elements.getType() << "\n"); | ||
attr.setValue(it->second); | ||
updated = true; | ||
continue; | ||
} | ||
|
||
// Convert. | ||
if (shouldConvertElements(elements)) { | ||
LLVM_DEBUG(llvm::dbgs() << ":: Converting elements attr of " | ||
<< elements.getType() << "\n"); | ||
if (auto replacement = convertElementsAttr(elements)) { | ||
attr.setValue(replacement); | ||
replacements[elements] = replacement; | ||
updated = true; | ||
} else { | ||
LLVM_DEBUG(llvm::dbgs() << " Failed to convert\n"); | ||
} | ||
} | ||
} | ||
} | ||
if (updated) | ||
op->setAttrs(attrs); | ||
}); | ||
LLVM_DEBUG(llvm::dbgs() << "DONE CONVERTING RESOURCES\n"); | ||
} | ||
|
||
static bool shouldConvertElements(ElementsAttr attr) { | ||
if (llvm::isa<DenseElementsAttr>(attr)) { | ||
// DenseElementsAttr encodes arbitrary dimension | ||
// splats whereas DenseResourceElementsAttr does not. | ||
return !attr.isSplat(); | ||
} | ||
|
||
return false; | ||
} | ||
|
||
static ElementsAttr convertElementsAttr(ElementsAttr elementsAttr) { | ||
auto st = llvm::cast<ShapedType>(elementsAttr.getType()); | ||
auto elementType = st.getElementType(); | ||
auto numElements = elementsAttr.getNumElements(); | ||
auto bitWidth = elementType.getIntOrFloatBitWidth(); | ||
AsmResourceBlob blob; | ||
if (auto attr = llvm::dyn_cast<DenseIntElementsAttr>(elementsAttr)) { | ||
switch (bitWidth) { | ||
case 1: | ||
blob = HeapAsmResourceBlob::allocate(numElements, /*align=*/64, | ||
/*dataIsMutable=*/true); | ||
copyIntAttrIntoBlob<uint8_t, /*numBits=*/1>(blob, attr); | ||
return DenseBlobResourceElementsAttr::get(st, "dense_elements_i1", | ||
std::move(blob)); | ||
case 8: | ||
blob = HeapAsmResourceBlob::allocate(numElements, /*align=*/64, | ||
/*dataIsMutable=*/true); | ||
copyIntAttrIntoBlob<uint8_t>(blob, attr); | ||
return DenseBlobResourceElementsAttr::get(st, "dense_elements_i8", | ||
std::move(blob)); | ||
case 16: | ||
blob = HeapAsmResourceBlob::allocate(2 * numElements, /*align=*/64, | ||
/*dataIsMutable=*/true); | ||
copyIntAttrIntoBlob<uint16_t>(blob, attr); | ||
return DenseBlobResourceElementsAttr::get(st, "dense_elements_i16", | ||
std::move(blob)); | ||
case 32: | ||
blob = HeapAsmResourceBlob::allocate(4 * numElements, /*align=*/64, | ||
/*dataIsMutable=*/true); | ||
copyIntAttrIntoBlob<uint32_t>(blob, attr); | ||
return DenseBlobResourceElementsAttr::get(st, "dense_elements_i32", | ||
std::move(blob)); | ||
case 64: | ||
blob = HeapAsmResourceBlob::allocate(8 * numElements, /*align=*/64, | ||
/*dataIsMutable=*/true); | ||
copyIntAttrIntoBlob<uint64_t>(blob, attr); | ||
return DenseBlobResourceElementsAttr::get(st, "dense_elements_i64", | ||
std::move(blob)); | ||
default: | ||
return {}; | ||
} | ||
} else if (auto attr = llvm::dyn_cast<DenseFPElementsAttr>(elementsAttr)) { | ||
AsmResourceBlob blob; | ||
switch (bitWidth) { | ||
case 8: | ||
blob = HeapAsmResourceBlob::allocate(numElements, /*align=*/64, | ||
/*dataIsMutable=*/true); | ||
copyFPAttrIntoBlob<uint8_t>(blob, attr); | ||
return DenseBlobResourceElementsAttr::get(st, "dense_elements_f8", | ||
std::move(blob)); | ||
case 16: | ||
blob = HeapAsmResourceBlob::allocate(2 * numElements, /*align=*/64, | ||
/*dataIsMutable=*/true); | ||
copyFPAttrIntoBlob<uint16_t>(blob, attr); | ||
return DenseBlobResourceElementsAttr::get(st, "dense_elements_f16", | ||
std::move(blob)); | ||
case 32: | ||
blob = HeapAsmResourceBlob::allocate(4 * numElements, /*align=*/64, | ||
/*dataIsMutable=*/true); | ||
copyFPAttrIntoBlob<uint32_t>(blob, attr); | ||
return DenseBlobResourceElementsAttr::get(st, "dense_elements_f32", | ||
std::move(blob)); | ||
case 64: | ||
blob = HeapAsmResourceBlob::allocate(8 * numElements, /*align=*/64, | ||
/*dataIsMutable=*/true); | ||
copyFPAttrIntoBlob<uint64_t>(blob, attr); | ||
return DenseBlobResourceElementsAttr::get(st, "dense_elements_f64", | ||
std::move(blob)); | ||
default: | ||
return {}; | ||
} | ||
} | ||
return {}; | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
std::unique_ptr<OperationPass<void>> createImportResourcesPass() { | ||
return std::make_unique<ImportResourcesPass>(); | ||
} | ||
|
||
} // namespace mlir::iree_compiler::IREE::Util |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
89 changes: 89 additions & 0 deletions
89
compiler/src/iree/compiler/Dialect/Util/Transforms/test/import_resources.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
// RUN: iree-opt --split-input-file --iree-util-import-resources %s | FileCheck %s | ||
|
||
// CHECK-LABEL: func.func @constant_splat_i64 | ||
func.func @constant_splat_i64() -> tensor<4xi64> { | ||
// Splats should not convert. | ||
// CHECK-NEXT: constant dense<123> | ||
%c123 = arith.constant dense<123> : tensor<4xi64> | ||
return %c123 : tensor<4xi64> | ||
} | ||
|
||
// ----- | ||
// CHECK-LABEL: func.func @dense_i1 | ||
func.func @dense_i1() -> tensor<4xi1> { | ||
// CHECK: dense_resource<dense_elements_i1> | ||
%c123 = arith.constant dense<[true, false, false, true]> : tensor<4xi1> | ||
return %c123 : tensor<4xi1> | ||
} | ||
|
||
// CHECK: dense_elements_i1: "0x4000000001000001" | ||
|
||
// ----- | ||
// CHECK-LABEL: func.func @dense_i8 | ||
func.func @dense_i8() -> tensor<4xi8> { | ||
// CHECK: dense_resource<dense_elements_i8> | ||
%c123 = arith.constant dense<[1, 2, 3, 127]> : tensor<4xi8> | ||
return %c123 : tensor<4xi8> | ||
} | ||
|
||
// CHECK: dense_elements_i8: "0x400000000102037F" | ||
|
||
// ----- | ||
// CHECK-LABEL: func.func @dense_i16 | ||
func.func @dense_i16() -> tensor<4xi16> { | ||
// CHECK: dense_resource<dense_elements_i16> | ||
%c123 = arith.constant dense<[1, 2, 3, 127]> : tensor<4xi16> | ||
return %c123 : tensor<4xi16> | ||
} | ||
|
||
// CHECK: dense_elements_i16: "0x400000000100020003007F00" | ||
|
||
// ----- | ||
// CHECK-LABEL: func.func @dense_i32 | ||
func.func @dense_i32() -> tensor<4xi32> { | ||
// CHECK: dense_resource<dense_elements_i32> | ||
%c123 = arith.constant dense<[1, 2, 3, 127]> : tensor<4xi32> | ||
return %c123 : tensor<4xi32> | ||
} | ||
|
||
// CHECK: dense_elements_i32: "0x400000000100000002000000030000007F000000" | ||
|
||
// ----- | ||
// CHECK-LABEL: func.func @dense_i64 | ||
func.func @dense_i64() -> tensor<4xi64> { | ||
// CHECK: dense_resource<dense_elements_i64> | ||
%c123 = arith.constant dense<[1, 2, 3, 127]> : tensor<4xi64> | ||
return %c123 : tensor<4xi64> | ||
} | ||
|
||
// CHECK: dense_elements_i64: "0x400000000100000000000000020000000000000003000000000000007F00000000000000" | ||
|
||
// ----- | ||
// CHECK-LABEL: func.func @dense_f16 | ||
func.func @dense_f16() -> tensor<4xf16> { | ||
// CHECK: dense_resource<dense_elements_f16> | ||
%c123 = arith.constant dense<[1.1, 2.2, 3.3, 0.0]> : tensor<4xf16> | ||
return %c123 : tensor<4xf16> | ||
} | ||
|
||
// CHECK: dense_elements_f16: "0x40000000663C66409A420000" | ||
|
||
// ----- | ||
// CHECK-LABEL: func.func @dense_f32 | ||
func.func @dense_f32() -> tensor<4xf32> { | ||
// CHECK: dense_resource<dense_elements_f32> | ||
%c123 = arith.constant dense<[1.1, 2.2, 3.3, 0.0]> : tensor<4xf32> | ||
return %c123 : tensor<4xf32> | ||
} | ||
|
||
// CHECK: dense_elements_f32: "0x40000000CDCC8C3FCDCC0C403333534000000000" | ||
|
||
// ----- | ||
// CHECK-LABEL: func.func @dense_f64 | ||
func.func @dense_f64() -> tensor<4xf64> { | ||
// CHECK: dense_resource<dense_elements_f64> | ||
%c123 = arith.constant dense<[1.1, 2.2, 3.3, 0.0]> : tensor<4xf64> | ||
return %c123 : tensor<4xf64> | ||
} | ||
|
||
// CHECK: dense_elements_f64: "0x400000009A9999999999F13F9A999999999901406666666666660A400000000000000000" |