Skip to content

Commit

Permalink
Add simple pass to turn dense attributes into dense_resource attribut…
Browse files Browse the repository at this point in the history
…es. (#14574)
  • Loading branch information
stellaraccident authored Aug 4, 2023
1 parent c35c88e commit 1791958
Show file tree
Hide file tree
Showing 8 changed files with 323 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ iree_compiler_cc_library(
"FuseGlobals.cpp",
"HoistIntoGlobals.cpp",
"IPO.cpp",
"ImportResources.cpp",
"PassDetail.h",
"Passes.cpp",
"Patterns.cpp",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ iree_cc_library(
"FuseGlobals.cpp"
"HoistIntoGlobals.cpp"
"IPO.cpp"
"ImportResources.cpp"
"PassDetail.h"
"Passes.cpp"
"Patterns.cpp"
Expand Down
206 changes: 206 additions & 0 deletions compiler/src/iree/compiler/Dialect/Util/Transforms/ImportResources.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <utility>

#include "iree/compiler/Dialect/Util/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/Debug.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/Pass/Pass.h"

#define DEBUG_TYPE "iree-util-import-resources"

namespace mlir::iree_compiler::IREE::Util {

namespace {

// TODO: Just use the DenseResourceElementsAttr::get()
// builder once https://reviews.llvm.org/D157064 lands.
class DenseBlobResourceElementsAttr : public DenseResourceElementsAttr {
public:
using DenseResourceElementsAttr::get;
};

template <typename ElementType, unsigned numBits = sizeof(ElementType) * 8>
static void copyIntAttrIntoBlob(AsmResourceBlob &blob,
DenseIntElementsAttr attr) {
ArrayRef<ElementType> data = blob.getDataAs<ElementType>();
MutableArrayRef<ElementType> rwData = MutableArrayRef<ElementType>(
const_cast<ElementType *>(data.data()), data.size());
ArrayRef<char> rawSrcData = attr.getRawData();
if (rawSrcData.size() == blob.getData().size()) {
// Memcpy.
std::memcpy(rwData.data(), rawSrcData.data(), rawSrcData.size());
} else {
// Slow.
size_t index = 0;
for (APInt value : attr.getValues<APInt>()) {
rwData[index++] = value.extractBitsAsZExtValue(numBits, 0);
}
}
}

template <typename ElementType, unsigned numBits = sizeof(ElementType) * 8>
static void copyFPAttrIntoBlob(AsmResourceBlob &blob,
DenseFPElementsAttr attr) {
ArrayRef<ElementType> data = blob.getDataAs<ElementType>();
MutableArrayRef<ElementType> rwData = MutableArrayRef<ElementType>(
const_cast<ElementType *>(data.data()), data.size());
ArrayRef<char> rawSrcData = attr.getRawData();
if (rawSrcData.size() == blob.getData().size()) {
// Memcpy.
std::memcpy(rwData.data(), rawSrcData.data(), rawSrcData.size());
} else {
// Slow.
size_t index = 0;
for (APFloat value : attr.getValues<APFloat>()) {
rwData[index++] =
value.bitcastToAPInt().extractBitsAsZExtValue(numBits, 0);
}
}
}

class ImportResourcesPass : public ImportResourcesBase<ImportResourcesPass> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<BuiltinDialect>();
}

void runOnOperation() override {
llvm::DenseMap<Attribute, Attribute> replacements;

getOperation()->walk([&](Operation *op) {
bool updated = false;
SmallVector<NamedAttribute> attrs(op->getAttrs());
for (auto &attr : attrs) {
if (auto elements = llvm::dyn_cast<ElementsAttr>(attr.getValue())) {
// Already seen?
auto it = replacements.find(elements);
if (it != replacements.end()) {
LLVM_DEBUG(llvm::dbgs()
<< ":: Replacing already encountered attr of "
<< elements.getType() << "\n");
attr.setValue(it->second);
updated = true;
continue;
}

// Convert.
if (shouldConvertElements(elements)) {
LLVM_DEBUG(llvm::dbgs() << ":: Converting elements attr of "
<< elements.getType() << "\n");
if (auto replacement = convertElementsAttr(elements)) {
attr.setValue(replacement);
replacements[elements] = replacement;
updated = true;
} else {
LLVM_DEBUG(llvm::dbgs() << " Failed to convert\n");
}
}
}
}
if (updated)
op->setAttrs(attrs);
});
LLVM_DEBUG(llvm::dbgs() << "DONE CONVERTING RESOURCES\n");
}

static bool shouldConvertElements(ElementsAttr attr) {
if (llvm::isa<DenseElementsAttr>(attr)) {
// DenseElementsAttr encodes arbitrary dimension
// splats whereas DenseResourceElementsAttr does not.
return !attr.isSplat();
}

return false;
}

static ElementsAttr convertElementsAttr(ElementsAttr elementsAttr) {
auto st = llvm::cast<ShapedType>(elementsAttr.getType());
auto elementType = st.getElementType();
auto numElements = elementsAttr.getNumElements();
auto bitWidth = elementType.getIntOrFloatBitWidth();
AsmResourceBlob blob;
if (auto attr = llvm::dyn_cast<DenseIntElementsAttr>(elementsAttr)) {
switch (bitWidth) {
case 1:
blob = HeapAsmResourceBlob::allocate(numElements, /*align=*/64,
/*dataIsMutable=*/true);
copyIntAttrIntoBlob<uint8_t, /*numBits=*/1>(blob, attr);
return DenseBlobResourceElementsAttr::get(st, "dense_elements_i1",
std::move(blob));
case 8:
blob = HeapAsmResourceBlob::allocate(numElements, /*align=*/64,
/*dataIsMutable=*/true);
copyIntAttrIntoBlob<uint8_t>(blob, attr);
return DenseBlobResourceElementsAttr::get(st, "dense_elements_i8",
std::move(blob));
case 16:
blob = HeapAsmResourceBlob::allocate(2 * numElements, /*align=*/64,
/*dataIsMutable=*/true);
copyIntAttrIntoBlob<uint16_t>(blob, attr);
return DenseBlobResourceElementsAttr::get(st, "dense_elements_i16",
std::move(blob));
case 32:
blob = HeapAsmResourceBlob::allocate(4 * numElements, /*align=*/64,
/*dataIsMutable=*/true);
copyIntAttrIntoBlob<uint32_t>(blob, attr);
return DenseBlobResourceElementsAttr::get(st, "dense_elements_i32",
std::move(blob));
case 64:
blob = HeapAsmResourceBlob::allocate(8 * numElements, /*align=*/64,
/*dataIsMutable=*/true);
copyIntAttrIntoBlob<uint64_t>(blob, attr);
return DenseBlobResourceElementsAttr::get(st, "dense_elements_i64",
std::move(blob));
default:
return {};
}
} else if (auto attr = llvm::dyn_cast<DenseFPElementsAttr>(elementsAttr)) {
AsmResourceBlob blob;
switch (bitWidth) {
case 8:
blob = HeapAsmResourceBlob::allocate(numElements, /*align=*/64,
/*dataIsMutable=*/true);
copyFPAttrIntoBlob<uint8_t>(blob, attr);
return DenseBlobResourceElementsAttr::get(st, "dense_elements_f8",
std::move(blob));
case 16:
blob = HeapAsmResourceBlob::allocate(2 * numElements, /*align=*/64,
/*dataIsMutable=*/true);
copyFPAttrIntoBlob<uint16_t>(blob, attr);
return DenseBlobResourceElementsAttr::get(st, "dense_elements_f16",
std::move(blob));
case 32:
blob = HeapAsmResourceBlob::allocate(4 * numElements, /*align=*/64,
/*dataIsMutable=*/true);
copyFPAttrIntoBlob<uint32_t>(blob, attr);
return DenseBlobResourceElementsAttr::get(st, "dense_elements_f32",
std::move(blob));
case 64:
blob = HeapAsmResourceBlob::allocate(8 * numElements, /*align=*/64,
/*dataIsMutable=*/true);
copyFPAttrIntoBlob<uint64_t>(blob, attr);
return DenseBlobResourceElementsAttr::get(st, "dense_elements_f64",
std::move(blob));
default:
return {};
}
}
return {};
}
};

} // namespace

std::unique_ptr<OperationPass<void>> createImportResourcesPass() {
return std::make_unique<ImportResourcesPass>();
}

} // namespace mlir::iree_compiler::IREE::Util
3 changes: 3 additions & 0 deletions compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ std::unique_ptr<OperationPass<mlir::ModuleOp>> createPropagateSubrangesPass();
std::unique_ptr<OperationPass<void>> createSimplifyGlobalAccessesPass();
std::unique_ptr<OperationPass<void>> createStripDebugOpsPass();

// Resource Management.
std::unique_ptr<OperationPass<void>> createImportResourcesPass();

// Type conversion.
std::unique_ptr<OperationPass<mlir::ModuleOp>> createDemoteI64ToI32Pass();
std::unique_ptr<OperationPass<mlir::ModuleOp>> createDemoteF32ToF16Pass();
Expand Down
21 changes: 21 additions & 0 deletions compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,27 @@ def SimplifyGlobalAccesses :
}];
}

//===----------------------------------------------------------------------===//
// Resource Management
//===----------------------------------------------------------------------===//

def ImportResources : Pass<"iree-util-import-resources", ""> {
let summary = "Imports IR with arbitrary large-data into resources that IREE can manage efficiently";
let description = [{
MLIR has many interesting ways to store large constants, most of which
derive from *ElementsAttr. Given the uniquing/inline behavior, this exacts
very large runtime and memory overhead costs.

This is a temporary pass to convert a majority of the legacy
DenseElementsAttr attributes to DenseResourceElementsAttr. Ideally this
is done at the source (frontend), but this pass is provided to aid
transition and testing by doing a manual conversion with iree-opt.
}];
let constructor = [{
mlir::iree_compiler::IREE::Util::createImportResourcesPass()
}];
}

//===----------------------------------------------------------------------===//
// Type Conversion
//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ iree_lit_test_suite(
"fuse_globals.mlir",
"hoist_into_globals.mlir",
"hoist_into_globals_linalg.mlir",
"import_resources.mlir",
"ipo.mlir",
"promote_bf16_to_f32.mlir",
"promote_f16_to_f32.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ iree_lit_test_suite(
"fuse_globals.mlir"
"hoist_into_globals.mlir"
"hoist_into_globals_linalg.mlir"
"import_resources.mlir"
"ipo.mlir"
"promote_bf16_to_f32.mlir"
"promote_f16_to_f32.mlir"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// RUN: iree-opt --split-input-file --iree-util-import-resources %s | FileCheck %s

// CHECK-LABEL: func.func @constant_splat_i64
func.func @constant_splat_i64() -> tensor<4xi64> {
// Splats should not convert.
// CHECK-NEXT: constant dense<123>
%c123 = arith.constant dense<123> : tensor<4xi64>
return %c123 : tensor<4xi64>
}

// -----
// CHECK-LABEL: func.func @dense_i1
func.func @dense_i1() -> tensor<4xi1> {
// CHECK: dense_resource<dense_elements_i1>
%c123 = arith.constant dense<[true, false, false, true]> : tensor<4xi1>
return %c123 : tensor<4xi1>
}

// CHECK: dense_elements_i1: "0x4000000001000001"

// -----
// CHECK-LABEL: func.func @dense_i8
func.func @dense_i8() -> tensor<4xi8> {
// CHECK: dense_resource<dense_elements_i8>
%c123 = arith.constant dense<[1, 2, 3, 127]> : tensor<4xi8>
return %c123 : tensor<4xi8>
}

// CHECK: dense_elements_i8: "0x400000000102037F"

// -----
// CHECK-LABEL: func.func @dense_i16
func.func @dense_i16() -> tensor<4xi16> {
// CHECK: dense_resource<dense_elements_i16>
%c123 = arith.constant dense<[1, 2, 3, 127]> : tensor<4xi16>
return %c123 : tensor<4xi16>
}

// CHECK: dense_elements_i16: "0x400000000100020003007F00"

// -----
// CHECK-LABEL: func.func @dense_i32
func.func @dense_i32() -> tensor<4xi32> {
// CHECK: dense_resource<dense_elements_i32>
%c123 = arith.constant dense<[1, 2, 3, 127]> : tensor<4xi32>
return %c123 : tensor<4xi32>
}

// CHECK: dense_elements_i32: "0x400000000100000002000000030000007F000000"

// -----
// CHECK-LABEL: func.func @dense_i64
func.func @dense_i64() -> tensor<4xi64> {
// CHECK: dense_resource<dense_elements_i64>
%c123 = arith.constant dense<[1, 2, 3, 127]> : tensor<4xi64>
return %c123 : tensor<4xi64>
}

// CHECK: dense_elements_i64: "0x400000000100000000000000020000000000000003000000000000007F00000000000000"

// -----
// CHECK-LABEL: func.func @dense_f16
func.func @dense_f16() -> tensor<4xf16> {
// CHECK: dense_resource<dense_elements_f16>
%c123 = arith.constant dense<[1.1, 2.2, 3.3, 0.0]> : tensor<4xf16>
return %c123 : tensor<4xf16>
}

// CHECK: dense_elements_f16: "0x40000000663C66409A420000"

// -----
// CHECK-LABEL: func.func @dense_f32
func.func @dense_f32() -> tensor<4xf32> {
// CHECK: dense_resource<dense_elements_f32>
%c123 = arith.constant dense<[1.1, 2.2, 3.3, 0.0]> : tensor<4xf32>
return %c123 : tensor<4xf32>
}

// CHECK: dense_elements_f32: "0x40000000CDCC8C3FCDCC0C403333534000000000"

// -----
// CHECK-LABEL: func.func @dense_f64
func.func @dense_f64() -> tensor<4xf64> {
// CHECK: dense_resource<dense_elements_f64>
%c123 = arith.constant dense<[1.1, 2.2, 3.3, 0.0]> : tensor<4xf64>
return %c123 : tensor<4xf64>
}

// CHECK: dense_elements_f64: "0x400000009A9999999999F13F9A999999999901406666666666660A400000000000000000"

0 comments on commit 1791958

Please sign in to comment.