From 3ca0a49b424c7d9918a2e73d0c19c308e7d8e6db Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Tue, 7 May 2024 14:38:36 -0700 Subject: [PATCH] Moving OutlineConstantsPass to flow and adding parameter support. (#17303) This allows us to hide the stream dialect attributes from frontends and use inline flow.tensor.constant ops with parameter attrs. Outlining now also properly preserves hoistable attrs such as stream affinity. By running IPO at the head of the flow pipeline we gain fusion opportunities for hoisted (by user or by global opt) constants and then we clean up the inlined constants at the end of flow so that the stream dialect can handle all values consistently. --- .../iree/compiler/tools/ir_tool/__main__.py | 2 +- .../iree/compiler/Dialect/Flow/IR/FlowBase.td | 3 + .../compiler/Dialect/Flow/IR/FlowTypes.cpp | 17 ++ .../Dialect/Flow/Transforms/BUILD.bazel | 1 + .../Dialect/Flow/Transforms/CMakeLists.txt | 1 + .../Flow/Transforms/ExportBenchmarkFuncs.cpp | 30 ++-- .../Flow/Transforms/OutlineConstants.cpp | 169 ++++++++++++++++++ .../Dialect/Flow/Transforms/Passes.cpp | 71 +++++++- .../Dialect/Flow/Transforms/Passes.td | 14 ++ .../Dialect/Flow/Transforms/test/BUILD.bazel | 1 + .../Flow/Transforms/test/CMakeLists.txt | 1 + .../test/export_benchmark_funcs.mlir | 18 +- .../Transforms/test/outline_constants.mlir | 79 ++++++++ .../Dialect/Stream/Conversion/BUILD.bazel | 1 + .../Dialect/Stream/Conversion/CMakeLists.txt | 1 + .../Conversion/FlowToStream/Patterns.cpp | 6 +- .../FlowToStream/test/tensor_ops.mlir | 2 +- .../Stream/Conversion/PatternUtils.cpp | 12 ++ .../Dialect/Stream/Conversion/PatternUtils.h | 4 + .../StandardToStream/ConvertConstantOps.cpp | 3 +- .../Conversion/UtilToStream/Patterns.cpp | 5 +- .../compiler/Dialect/Stream/IR/StreamOps.td | 3 +- .../Dialect/Stream/Transforms/Passes.cpp | 9 - .../compiler/Dialect/Util/IR/UtilAttrs.cpp | 40 +++++ .../Dialect/Util/IR/UtilInterfaces.td | 11 ++ .../Dialect/Util/Transforms/BUILD.bazel | 1 - .../Dialect/Util/Transforms/CMakeLists.txt | 1 - .../Util/Transforms/HoistIntoGlobals.cpp | 61 ++++--- .../Util/Transforms/OutlineConstants.cpp | 124 ------------- .../compiler/Dialect/Util/Transforms/Passes.h | 1 - .../Dialect/Util/Transforms/Passes.td | 8 - .../Dialect/Util/Transforms/test/BUILD.bazel | 1 - .../Util/Transforms/test/CMakeLists.txt | 1 - .../Transforms/test/hoist_into_globals.mlir | 14 +- .../test/hoist_into_globals_linalg.mlir | 2 +- .../Transforms/test/outline_constants.mlir | 30 ---- .../VM/Transforms/GlobalInitialization.cpp | 2 + .../test/global_initialization.mlir | 37 ++++ .../test/hoist_into_globals.mlir | 10 +- .../IO/Parameters/Transforms/BUILD.bazel | 2 +- .../IO/Parameters/Transforms/CMakeLists.txt | 2 +- .../Transforms/ExportParameters.cpp | 4 +- .../GenerateSplatParameterArchive.cpp | 53 ++++-- .../Transforms/ImportParameters.cpp | 11 +- .../IO/Parameters/Transforms/Passes.td | 4 +- .../Transforms/test/export_parameters.mlir | 20 +-- .../generate_splat_parameter_archive.mlir | 17 +- .../Transforms/test/import_parameters.mlir | 10 +- .../bindings/python/tests/io_runtime_test.py | 8 +- .../parameters/generate_splat_archive.mlir | 8 +- tools/test/parameters_scoped.mlir | 8 +- tools/test/parameters_unscoped.mlir | 8 +- 52 files changed, 636 insertions(+), 316 deletions(-) create mode 100644 compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineConstants.cpp create mode 100644 compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_constants.mlir delete mode 100644 compiler/src/iree/compiler/Dialect/Util/Transforms/OutlineConstants.cpp delete mode 100644 compiler/src/iree/compiler/Dialect/Util/Transforms/test/outline_constants.mlir diff --git a/compiler/bindings/python/iree/compiler/tools/ir_tool/__main__.py b/compiler/bindings/python/iree/compiler/tools/ir_tool/__main__.py index 589b710696df..574e0e5e7aa0 100644 --- a/compiler/bindings/python/iree/compiler/tools/ir_tool/__main__.py +++ b/compiler/bindings/python/iree/compiler/tools/ir_tool/__main__.py @@ -113,7 +113,7 @@ def do_strip_data(args) -> int: ): return 1 if not inv.execute_text_pass_pipeline( - "iree-util-outline-constants, iree-util-strip-and-splat-constants" + "iree-flow-outline-constants, iree-util-strip-and-splat-constants" ): return 2 write_output(inv, output, args) diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowBase.td b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowBase.td index 6528cf55fde6..aac974b16fa1 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowBase.td +++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowBase.td @@ -241,6 +241,9 @@ def FLOW_CollectiveReductionOpAttr : def FLOW_NamedParameterAttr : AttrDef, ]> { let mnemonic = "parameter.named"; let summary = [{named parameter referenced an optional scope and key}]; diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowTypes.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowTypes.cpp index 91a970a26c24..58c853fb08cb 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowTypes.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowTypes.cpp @@ -310,4 +310,21 @@ void printParameterReference(AsmPrinter &p, StringAttr scopeAttr, p << "\"" << keyAttr.getValue() << "\""; } +//===----------------------------------------------------------------------===// +// #flow.parameter.named<...> +//===----------------------------------------------------------------------===// + +int64_t NamedParameterAttr::getStorageSize() const { + if (auto configAttr = getConfig()) { + if (auto lengthAttr = configAttr.getAs("length")) { + return lengthAttr.getInt(); + } + } + if (auto shapedType = llvm::dyn_cast(getType())) { + return IREE::Util::getRoundedPhysicalStorageSize(shapedType); + } else { + return IREE::Util::getTypePhysicalStorageBitWidth(getType()); + } +} + } // namespace mlir::iree_compiler::IREE::Flow diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel index 1b2339ed0195..2c40f820ac24 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel @@ -57,6 +57,7 @@ iree_compiler_cc_library( "InjectTensorTracing.cpp", "InsertDispatchDebugTargets.cpp", "InterchangeTransposeGenericOps.cpp", + "OutlineConstants.cpp", "OutlineDispatchExterns.cpp", "OutlineDispatchRegions.cpp", "Passes.cpp", diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt index 7b14e0427f05..007891c828a8 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt @@ -57,6 +57,7 @@ iree_cc_library( "InjectTensorTracing.cpp" "InsertDispatchDebugTargets.cpp" "InterchangeTransposeGenericOps.cpp" + "OutlineConstants.cpp" "OutlineDispatchExterns.cpp" "OutlineDispatchRegions.cpp" "Passes.cpp" diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ExportBenchmarkFuncs.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/ExportBenchmarkFuncs.cpp index 1980b3d9ba86..5d6c3fef4137 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ExportBenchmarkFuncs.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/ExportBenchmarkFuncs.cpp @@ -178,19 +178,25 @@ static IREE::Util::GlobalOp createDummyInput(const std::string &namePrefix, OpBuilder &moduleBuilder, Explorer &explorer) { std::string name = namePrefix + "_arg" + std::to_string(arg.getArgNumber()); - return TypeSwitch(arg.getType()) - .Case([&](IREE::HAL::BufferViewType type) { - return createImportBufferViewGlobalOp(name, arg, symbolTable, + auto globalOp = + TypeSwitch(arg.getType()) + .Case([&](IREE::HAL::BufferViewType type) { + return createImportBufferViewGlobalOp(name, arg, symbolTable, + moduleBuilder, explorer); + }) + .Case([&](IREE::HAL::BufferType type) { + return createExportBufferGlobalOp(name, arg, symbolTable, moduleBuilder, explorer); - }) - .Case([&](IREE::HAL::BufferType type) { - return createExportBufferGlobalOp(name, arg, symbolTable, moduleBuilder, - explorer); - }) - .Default([&](Type type) { - return createPrimitiveDefaultGlobalOp(name, arg.getLoc(), type, - symbolTable, moduleBuilder); - }); + }) + .Default([&](Type type) { + return createPrimitiveDefaultGlobalOp(name, arg.getLoc(), type, + symbolTable, moduleBuilder); + }); + if (globalOp) { + // Prevent globals from folding so that we have unique buffers for each arg. + globalOp->setAttr("flow.unique_id", moduleBuilder.getStringAttr(name)); + } + return globalOp; } static LogicalResult diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineConstants.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineConstants.cpp new file mode 100644 index 000000000000..0e1562bc3aca --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineConstants.cpp @@ -0,0 +1,169 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include + +#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h" +#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" +#include "iree/compiler/Dialect/Flow/Transforms/Passes.h" +#include "iree/compiler/Dialect/Util/IR/UtilDialect.h" +#include "iree/compiler/Dialect/Util/IR/UtilOps.h" +#include "iree/compiler/Utils/StringUtils.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/Pass/Pass.h" + +namespace mlir::iree_compiler::IREE::Flow { + +#define GEN_PASS_DEF_OUTLINECONSTANTSPASS +#include "iree/compiler/Dialect/Flow/Transforms/Passes.h.inc" + +namespace { + +// Returns true if |value| is worth outlining (large, etc). +static bool isOutlinableValue(Attribute value) { + if (auto elementsAttr = llvm::dyn_cast(value)) { + // Don't outline splats - we want those fused. + return !elementsAttr.isSplat(); + } else if (isa(value)) { + // Always outline parameter constants. + return true; + } + return false; +} + +struct ConstantDef { + Operation *op; + Type type; + TypedAttr value; +}; + +// Returns a list of all constant-like shaped data ops in the module. +static SmallVector findConstantsInModule(mlir::ModuleOp moduleOp) { + SmallVector results; + for (auto callableOp : moduleOp.getOps()) { + auto *region = callableOp.getCallableRegion(); + if (!region) + continue; + region->walk([&](Operation *op) { + if (auto constantOp = dyn_cast(op)) { + if (isOutlinableValue(constantOp.getValue())) { + results.push_back(ConstantDef{ + constantOp, + constantOp.getType(), + constantOp.getValue(), + }); + } + } else if (auto constantOp = dyn_cast(op)) { + if (isOutlinableValue(constantOp.getValue())) { + results.push_back(ConstantDef{ + constantOp, + constantOp.getType(), + constantOp.getValue(), + }); + } + } + }); + } + return results; +} + +// Returns the operation containing |childOp| that is a direct child of +// |ancestorOp|. May return |childOp|. +static Operation *getParentInOp(Operation *childOp, Operation *ancestorOp) { + assert(childOp != ancestorOp && "child can't be its own ancestor"); + do { + auto *parentOp = childOp->getParentOp(); + if (parentOp == ancestorOp) + return childOp; + childOp = parentOp; + } while (childOp); + assert(false && "child must be nested under ancestor"); + return nullptr; +} + +static std::string getConstantName(ConstantDef &def) { + std::string str; + llvm::raw_string_ostream os(str); + if (auto parameterAttr = + dyn_cast(def.value)) { + os << "__parameter_"; + if (parameterAttr.getScope() && !parameterAttr.getScope().empty()) + os << parameterAttr.getScope().getValue() << "_"; + os << parameterAttr.getKey().getValue() << "_"; + } else { + os << "__constant_"; + } + def.type.print(os); + str = sanitizeSymbolName(str); + if (str.substr(str.size() - 1) == "_") + str = str.substr(0, str.size() - 1); // strip trailing _ + return str; +} + +//===----------------------------------------------------------------------===// +// --iree-flow-outline-constants +//===----------------------------------------------------------------------===// + +struct OutlineConstantsPass + : public IREE::Flow::impl::OutlineConstantsPassBase { + void runOnOperation() override { + auto moduleOp = getOperation(); + if (moduleOp.getBody()->empty()) + return; + + SymbolTable moduleSymbols(moduleOp); + + // Create all top-level util.globals from constants in the module. + std::vector> replacements; + for (auto &def : findConstantsInModule(moduleOp)) { + // Position the global immediately preceding the top-level op that + // contains the constant. + OpBuilder moduleBuilder(&moduleOp.getBody()->front()); + auto parentFuncOp = getParentInOp(def.op, moduleOp); + if (parentFuncOp) + moduleBuilder.setInsertionPoint(parentFuncOp); + + // New immutable global takes the constant attribute in its specified + // encoding. + auto globalOp = moduleBuilder.create( + def.op->getLoc(), getConstantName(def), /*isMutable=*/false, def.type, + def.value); + globalOp.setPrivate(); + IREE::Util::HoistableAttrInterface::gatherHoistableAttrs(def.op, + globalOp); + moduleSymbols.insert(globalOp); // uniques name + replacements.emplace_back(def.op, globalOp); + + // Prevent the variable from being re-inlined if the canonicalizer runs. + // By the time we've outlined things here we are sure we want them + // outlined even if the user runs an arbitrary number of passes between + // now and when we may use that information (HAL constant pooling, etc). + globalOp.setInliningPolicyAttr( + moduleBuilder.getAttr()); + } + + // Replace all of the constants with lookups for the new variables. + for (auto pair : replacements) { + auto *originalOp = pair.first; + auto globalOp = pair.second; + OpBuilder builder(moduleOp.getContext()); + builder.setInsertionPoint(originalOp); + auto loadOp = globalOp.createLoadOp(originalOp->getLoc(), builder); + loadOp.setGlobalImmutable(true); + originalOp->getResult(0).replaceAllUsesWith( + loadOp.getLoadedGlobalValue()); + originalOp->erase(); + } + } +}; + +} // namespace + +} // namespace mlir::iree_compiler::IREE::Flow diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp index 941d66e2fa4e..698629ceeb8d 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp @@ -121,6 +121,36 @@ namespace mlir::iree_compiler::IREE::Flow { using FunctionLikeNest = MultiOpNest; +//===----------------------------------------------------------------------===// +// Utilities +//===----------------------------------------------------------------------===// + +static void addCleanupPatterns(OpPassManager &passManager) { + FunctionLikeNest(passManager) + // Standard MLIR cleanup. + .addPass(mlir::createCanonicalizerPass) + .addPass(mlir::createCSEPass) + + // Simplify util.global accesses; this can help with data flow tracking as + // redundant store-loads are removed. + .addPass(IREE::Util::createSimplifyGlobalAccessesPass); + + // Cleanup and canonicalization of util.global (and other util ops). + passManager.addPass(IREE::Util::createApplyPatternsPass()); + passManager.addPass(IREE::Util::createFoldGlobalsPass()); + passManager.addPass(IREE::Util::createFuseGlobalsPass()); + + // Large IPO pass. Note that this can introduce a significant amount of + // duplication/inlined constants and we'll want to ensure we're running + // cleanup again after (this entire set of patterns is run in a fixed-point + // iteration to do that). + passManager.addPass(IREE::Util::createIPOPass()); +} + +//===----------------------------------------------------------------------===// +// Pipelines +//===----------------------------------------------------------------------===// + void addDispatchRegionCreationPreprocessingPasses(OpPassManager &passManager) { // 1. Do some simple elementwise op fusion. This could be skipped, // but could reduce the surface area of ops to handle later. @@ -240,6 +270,22 @@ void buildFlowTransformPassPipeline(OpPassManager &passManager, clEnableFusePaddingIntoLinalgConsumerOps})); } + { + // We run these under a fixed-point iteration such that we can perform + // inter-procedural, intra-procedural, and canonicalization as separably + // verifiable/reusable passes. IPO will fold duplicate arguments/results + // and inline constants to allow the local optimizations to work more + // effectively. + OpPassManager ipoPipeline(mlir::ModuleOp::getOperationName()); + + // IPO and other cleanups. + addCleanupPatterns(ipoPipeline); + + // Run fixed-point iteration on the IPO pipeline. + passManager.addPass( + IREE::Util::createFixedPointIteratorPass(std::move(ipoPipeline))); + } + addDispatchRegionCreationPasses(passManager, transformOptions); FunctionLikeNest(passManager) @@ -325,9 +371,28 @@ void buildFlowTransformPassPipeline(OpPassManager &passManager, // passes above after we've formed dispatch regions. .addPass(IREE::Flow::createInjectTensorTracingPass) // Cleanup the IR after we are done. - .addPass(IREE::Flow::createCleanupTensorShapesPass) - .addPass(mlir::createCanonicalizerPass) - .addPass(mlir::createCSEPass); + .addPass(IREE::Flow::createCleanupTensorShapesPass); + + { + // We run these under a fixed-point iteration such that we can perform + // inter-procedural, intra-procedural, and canonicalization as separably + // verifiable/reusable passes. IPO will fold duplicate arguments/results + // and inline constants to allow the local optimizations to work more + // effectively. + OpPassManager ipoPipeline(mlir::ModuleOp::getOperationName()); + + // Turn all constant ops into global variables and fix up the IR. + // As many locations change and constants are deduplicated we'll end up with + // a lot of extraneous IR (mostly global loads) and clean those up here. + ipoPipeline.addPass(IREE::Flow::createOutlineConstantsPass()); + + // IPO and other cleanups. + addCleanupPatterns(ipoPipeline); + + // Run fixed-point iteration on the IPO pipeline. + passManager.addPass( + IREE::Util::createFixedPointIteratorPass(std::move(ipoPipeline))); + } // Cleanup executable contents. { diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td index 2f3360a01d29..592d015e9dec 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td @@ -391,6 +391,20 @@ def InterchangeTransposeGenericOpsPass : ]; } +def OutlineConstantsPass : + Pass<"iree-flow-outline-constants", "mlir::ModuleOp"> { + let summary = "Outlines tensor constants into util.globals at the module level."; + let description = [{ + Outlines tensor constants throughout the program into globals initialized + with stream ops. + }]; + let dependentDialects = [ + "mlir::arith::ArithDialect", + "IREE::Flow::FlowDialect", + "IREE::Util::UtilDialect", + ]; +} + def OutlineDispatchExternsPass : Pass<"iree-flow-outline-dispatch-externs", "mlir::ModuleOp"> { let summary = "Outlines external dispatches into executables."; diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel index bb2326656bcf..b9ce61bf48ae 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel @@ -42,6 +42,7 @@ iree_lit_test_suite( "inject_tensor_tracing.mlir", "insert_dispatch_debug_targets.mlir", "interchange_transpose_generic_ops.mlir", + "outline_constants.mlir", "outline_dispatch_externs.mlir", "outline_dispatch_regions.mlir", "pad_fusion_with_consumer.mlir", diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt index 4fda8735aed5..e7df9deea7f3 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt @@ -40,6 +40,7 @@ iree_lit_test_suite( "inject_tensor_tracing.mlir" "insert_dispatch_debug_targets.mlir" "interchange_transpose_generic_ops.mlir" + "outline_constants.mlir" "outline_dispatch_externs.mlir" "outline_dispatch_regions.mlir" "pad_fusion_with_consumer.mlir" diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/export_benchmark_funcs.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/export_benchmark_funcs.mlir index dfb93d5706fe..03f5b56fc369 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/export_benchmark_funcs.mlir +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/export_benchmark_funcs.mlir @@ -11,8 +11,8 @@ util.func public @simpleMul(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> util.return %3 : !hal.buffer_view } -// CHECK: util.global private @[[GLOBAL_ARG0:.+]] {inlining_policy = #util.inline.never} : !hal.buffer_view -// CHECK: util.global private @[[GLOBAL_ARG1:.+]] {inlining_policy = #util.inline.never} : !hal.buffer_view +// CHECK: util.global private @[[GLOBAL_ARG0:.+]] { +// CHECK: util.global private @[[GLOBAL_ARG1:.+]] { // CHECK: util.func public @simpleMul_benchmark() attributes {iree.abi.stub, iree.reflection = {iree.benchmark = "entry"}} { // CHECK-DAG: %[[ARG0:.+]] = util.global.load @[[GLOBAL_ARG0]] : !hal.buffer_view @@ -37,12 +37,12 @@ util.func public @while(%start: i32, %bound: i32) -> i32 { util.return %5 : i32 } -// CHECK: util.global private @[[GLOBAL_ARG0:.+]] {inlining_policy = #util.inline.never} = 0 : i32 -// CHECK: util.global private @[[GLOBAL_ARG1:.+]] {inlining_policy = #util.inline.never} = 0 : i32 +// CHECK: util.global private @[[GLOBAL_ARG0:.+]] {{{.+}}} = 0 : i32 +// CHECK: util.global private @[[GLOBAL_ARG1:.+]] {{{.+}}} = 0 : i32 // CHECK: util.func public @while_benchmark() -// CHECK-DAG: %[[ARG0:.+]] = util.global.load @[[GLOBAL_ARG0]] : i32 -// CHECK-DAG: %[[ARG1:.+]] = util.global.load @[[GLOBAL_ARG1]] : i32 +// CHECK-DAG: %[[ARG0:.+]] = util.global.load immutable @[[GLOBAL_ARG0]] : i32 +// CHECK-DAG: %[[ARG1:.+]] = util.global.load immutable @[[GLOBAL_ARG1]] : i32 // CHECK: %[[RET0:.+]] = util.call @while(%[[ARG0]], %[[ARG1]]) // CHECK: util.optimization_barrier %[[RET0]] : i32 // CHECK: util.return @@ -59,7 +59,7 @@ util.func public @importBufferViewBitcasting(%view: !hal.buffer_view) -> !hal.bu util.return %2 : !hal.buffer_view } -// CHECK: util.global private @[[GLOBAL_ARG0:.+]] {inlining_policy = #util.inline.never} : !hal.buffer_view +// CHECK: util.global private @[[GLOBAL_ARG0:.+]] { // CHECK: util.initializer { // CHECK-DAG: %[[SPLAT:.+]] = flow.tensor.splat %c0_i32 // CHECK-DAG: %[[EXPORT:.+]] = hal.tensor.export %[[SPLAT]] : tensor<4xi32> -> !hal.buffer_view @@ -99,14 +99,14 @@ util.func public @exportBufferViewInPlace(%view: !hal.buffer_view, %storage: !ha util.return %2 : !hal.buffer_view } -// CHECK: util.global private @[[GLOBAL_ARG0:.+]] {inlining_policy = #util.inline.never} : !hal.buffer_view +// CHECK: util.global private @[[GLOBAL_ARG0:.+]] { // CHECK: util.initializer { // CHECK-DAG: %[[SPLAT0:.+]] = flow.tensor.splat %c0_i32 // CHECK-DAG: %[[EXPORT0:.+]] = hal.tensor.export %[[SPLAT0]] : tensor<4xi32> -> !hal.buffer_view // CHECK-DAG: %[[DNO0:.+]] = util.optimization_barrier %[[EXPORT0]] // CHECK-NEXT: util.global.store %[[DNO0]], @[[GLOBAL_ARG0]] -// CHECK: util.global private @[[GLOBAL_ARG1:.+]] {inlining_policy = #util.inline.never} : !hal.buffer +// CHECK: util.global private @[[GLOBAL_ARG1:.+]] { // CHECK: util.initializer { // CHECK-DAG: %[[SPLAT1:.+]] = flow.tensor.splat %c0_i32 // CHECK-DAG: %[[EXPORT1:.+]] = hal.tensor.export %[[SPLAT1]] : tensor<4xi32> -> !hal.buffer diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_constants.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_constants.mlir new file mode 100644 index 000000000000..e3db1b69795d --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_constants.mlir @@ -0,0 +1,79 @@ +// RUN: iree-opt --split-input-file --iree-flow-outline-constants %s | FileCheck %s + +// Tests that we don't outline splats (as we want them to be transients). + +// CHECK-LABEL: @splatConstant +util.func @splatConstant() { + // CHECK-DAG: = arith.constant dense<1> : tensor<512x128xi32> + %arith_cst = arith.constant dense<1> : tensor<512x128xi32> + // CHECK-DAG: = flow.tensor.constant dense<1> : tensor<512x128xi32> + %flow_cst = flow.tensor.constant dense<1> : tensor<512x128xi32> + util.return +} + +// ----- + +// Tests that constant parameters are outlined. + +// CHECK: util.global private @__parameter_scope_key_tensor_4x2xi32 {inlining_policy = #util.inline.never} = #flow.parameter.named<"scope"::"key"> : tensor<4x2xi32> +// CHECK-LABEL: @parameterConstant +util.func @parameterConstant() { + // CHECK: = util.global.load immutable @__parameter_scope_key_tensor_4x2xi32 : tensor<4x2xi32> + %cst = flow.tensor.constant #flow.parameter.named<"scope"::"key"> : tensor<4x2xi32> + util.return +} + +// ----- + +// Tests that multiple constants will be hoisted and named uniquely. + +// CHECK: util.global private @__constant_tensor_2xf32 {inlining_policy = #util.inline.never} = dense<[0.0287729427, 0.0297581609]> : tensor<2xf32> +// CHECK-NEXT: util.global private @__constant_tensor_2xf32_0 {inlining_policy = #util.inline.never} = dense<[0.000000e+00, 1.000000e+00]> : tensor<2xf32> +// CHECK-NEXT: util.func private @denseConstants +util.func private @denseConstants() { + // CHECK-NEXT: = util.global.load immutable @__constant_tensor_2xf32 : tensor<2xf32> + %cst_0 = arith.constant dense<[0.0287729427, 0.0297581609]> : tensor<2xf32> + // CHECK-NEXT: = util.global.load immutable @__constant_tensor_2xf32_0 : tensor<2xf32> + %cst_1 = flow.tensor.constant dense<[0.000000e+00, 1.000000e+00]> : tensor<2xf32> + util.return +} + +// ----- + +// Tests that constants are outlined to the module scope above their use to +// preserve ordering of existing functions/globals. + +// CHECK: util.func private @external_func +util.func private @external_func() +// CHECK-NEXT: util.global private @__constant_tensor_2xi32 +// CHECK-NEXT: util.func private @func_0() +util.func private @func_0() { + // CHECK-NEXT: = util.global.load immutable @__constant_tensor_2xi32 + %cst_0 = arith.constant dense<[0, 1]> : tensor<2xi32> + util.return +} + +// CHECK: util.global private @existing_global +util.global private @existing_global : tensor<4xf32> +// CHECK-NEXT: util.global private @__constant_tensor_3xi32 +// CHECK-NEXT: util.func private @func_1() +util.func private @func_1() { + // CHECK-NEXT: = util.global.load immutable @__constant_tensor_3xi32 + %cst_1 = arith.constant dense<[2, 3, 4]> : tensor<3xi32> + util.return +} + +// ----- + +// Tests that any hoistable attrs are propagated to the outlined globals. + +// CHECK: util.global private @__constant_tensor_2xi32 +// CHECK-SAME: stream.affinity = #hal.affinity.queue<[0]> +// CHECK-NEXT: util.func private @set_affinity +util.func private @set_affinity() attributes { + stream.affinity = #hal.affinity.queue<[0]> +} { + // CHECK-NEXT: = util.global.load immutable @__constant_tensor_2xi32 + %cst = arith.constant dense<[0, 1]> : tensor<2xi32> + util.return +} diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/Conversion/BUILD.bazel index 3116a862d1bc..fbc0e51d4463 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/BUILD.bazel @@ -21,6 +21,7 @@ iree_compiler_cc_library( "PatternUtils.h", ], deps = [ + "//compiler/src/iree/compiler/Dialect/Flow/IR", "//compiler/src/iree/compiler/Dialect/Stream/IR", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FunctionInterfaces", diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Stream/Conversion/CMakeLists.txt index 472fdb95bb93..05bbb79469cf 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/CMakeLists.txt @@ -23,6 +23,7 @@ iree_cc_library( MLIRIR MLIRTransformUtils MLIRTransforms + iree::compiler::Dialect::Flow::IR iree::compiler::Dialect::Stream::IR PUBLIC ) diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp index 93f1aef6fde8..0942148c0e7c 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp @@ -46,7 +46,8 @@ struct ConvertTensorConstantOp getContext(), IREE::Stream::Lifetime::Constant); auto affinityAttr = IREE::Stream::AffinityAttr::lookup(constantOp); auto newOp = rewriter.create( - constantOp.getLoc(), constantType, constantOp.getValue(), + constantOp.getLoc(), constantType, + convertAttributeToStream(constantOp.getValue()), TypeAttr::get(constantOp.getType()), ValueRange{}, affinityAttr); // Transfer to unknown lifetime. @@ -94,7 +95,8 @@ struct ConvertTensorDynamicConstantOp getContext(), IREE::Stream::Lifetime::Constant); auto affinityAttr = IREE::Stream::AffinityAttr::lookup(constantOp); auto newOp = rewriter.create( - constantOp.getLoc(), constantType, constantOp.getValue(), + constantOp.getLoc(), constantType, + convertAttributeToStream(constantOp.getValue()), TypeAttr::get(resultType), dynamicDims, affinityAttr); // Transfer to unknown lifetime. diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir index 1eb0dec28be8..9a1272fdf4c6 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir @@ -17,7 +17,7 @@ util.func public @tensorConstantParameter() -> tensor<4x2xi32> { // CHECK-DAG: %[[CST:.+]] = stream.tensor.constant : tensor<4x2xi32> in !stream.resource = #stream.parameter.named<"scope"::"key"> : tensor<4x2xi32> // CHECK-DAG: %[[SIZE:.+]] = stream.resource.size %[[CST]] : !stream.resource // CHECK-DAG: %[[TRANSFER:.+]] = stream.async.transfer %[[CST]] : !stream.resource{%[[SIZE]]} -> !stream.resource<*>{%[[SIZE]]} - %cst = flow.tensor.constant #stream.parameter.named<"scope"::"key"> : tensor<4x2xi32> + %cst = flow.tensor.constant #flow.parameter.named<"scope"::"key"> : tensor<4x2xi32> // CHECK: util.return %[[TRANSFER]], %[[SIZE]] util.return %cst : tensor<4x2xi32> } diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.cpp index 46c1c836f9e5..6bb26f1644e1 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.cpp @@ -6,11 +6,23 @@ #include "iree/compiler/Dialect/Stream/Conversion/PatternUtils.h" +#include "iree/compiler/Dialect/Flow/IR/FlowTypes.h" #include "iree/compiler/Dialect/Stream/IR/StreamOps.h" #include "iree/compiler/Dialect/Stream/IR/StreamTypes.h" namespace mlir::iree_compiler { +TypedAttr convertAttributeToStream(TypedAttr attr) { + if (!attr) + return {}; + if (auto parameterAttr = dyn_cast(attr)) { + return IREE::Stream::NamedParameterAttr::get( + attr.getContext(), parameterAttr.getType(), parameterAttr.getScope(), + parameterAttr.getKey(), parameterAttr.getConfig()); + } + return attr; +} + void expandResourceOperand(Location loc, Value operand, SmallVectorImpl &newOperands, OpBuilder &builder) { diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.h b/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.h index a7a864f8ad93..fd9249e0801e 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.h +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.h @@ -13,6 +13,10 @@ namespace mlir::iree_compiler { +// Converts a supported attribute type to the corresponding stream dialect +// value. Returns the provided value if it is natively supported. +TypedAttr convertAttributeToStream(TypedAttr attr); + void expandResourceOperand(Location loc, Value operand, SmallVectorImpl &newOperands, OpBuilder &builder); diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertConstantOps.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertConstantOps.cpp index 127bd74f27ee..5ff99f7f84ab 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertConstantOps.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertConstantOps.cpp @@ -4,6 +4,7 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +#include "iree/compiler/Dialect/Stream/Conversion/PatternUtils.h" #include "iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.h" #include "iree/compiler/Dialect/Stream/IR/StreamOps.h" #include "iree/compiler/Dialect/Stream/IR/StreamTypes.h" @@ -34,7 +35,7 @@ struct ConvertTensorConstantOp : public OpConversionPattern { auto affinityAttr = IREE::Stream::AffinityAttr::lookup(constantOp); auto newOp = rewriter.create( constantOp.getLoc(), constantType, - llvm::cast(constantOp.getValue()), + convertAttributeToStream(constantOp.getValue()), TypeAttr::get(constantOp.getType()), /*result_encoding_dims=*/ValueRange{}, affinityAttr); diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.cpp index a8e8d078a71a..4d1fa5f8a677 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.cpp @@ -249,7 +249,8 @@ struct GlobalOpExpansion affinityAttr); } else { initialValue = rewriter.create( - globalOp.getLoc(), resourceOp.getType(), initialValueAttr, + globalOp.getLoc(), resourceOp.getType(), + convertAttributeToStream(initialValueAttr), TypeAttr::get(globalOp.getType()), /*result_encoding_dims=*/ValueRange{}, affinityAttr); initialValueSize = rewriter.create( @@ -404,7 +405,7 @@ void populateUtilToStreamConversionPatterns(MLIRContext *context, [&](IREE::Util::GlobalOp op) { return typeConverter.isLegal(op.getType()) && (!op.getInitialValueAttr() || - !llvm::isa(op.getInitialValueAttr().getType())); + !isExpandedType(op.getInitialValueAttr().getType())); }); conversionTarget.addDynamicallyLegalOp( [&](IREE::Util::GlobalAddressOp op) { diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td index 1ee27c093e94..c994e65fb587 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td @@ -11,6 +11,7 @@ include "iree/compiler/Dialect/Stream/IR/StreamBase.td" include "iree/compiler/Dialect/Stream/IR/StreamInterfaces.td" include "iree/compiler/Dialect/Util/IR/UtilInterfaces.td" include "mlir/Interfaces/FunctionInterfaces.td" +include "mlir/IR/BuiltinAttributeInterfaces.td" include "mlir/IR/OpAsmInterface.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/CallInterfaces.td" @@ -1323,7 +1324,7 @@ def Stream_TensorConstantOp : Stream_PureOp<"tensor.constant", [ }]; let arguments = (ins - AnyAttr:$value, + TypedAttrInterface:$value, TypeAttr:$result_encoding, Stream_ShapeDynamicDims:$result_encoding_dims, OptionalAttr:$affinity diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp index 5800ded60697..a0861e7c7fd6 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp @@ -64,15 +64,6 @@ void buildStreamTensorPassPipeline(OpPassManager &passManager, // propagation or fusion that needs to happen first. addCleanupPatterns(passManager); - // Turn all constant ops into global variables and fix up the IR. - // As many locations change and constants are deduplicated we'll end up with - // a lot of extraneous IR (mostly global loads) and clean those up here. - passManager.addPass(IREE::Util::createOutlineConstantsPass()); - - // Perform cleanup after constant simplification as more canonicalizers may be - // able to kick in. - addCleanupPatterns(passManager); - //---------------------------------------------------------------------------- // Conversion //---------------------------------------------------------------------------- diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.cpp index 513561014e27..03cc0f4f1b32 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.cpp @@ -886,6 +886,46 @@ struct SerializableStringAttrModel } }; +//===----------------------------------------------------------------------===// +// IREE::Util::Hoistable*Interface +//===----------------------------------------------------------------------===// + +// Walks |fromOp| and up to gather all dialect attributes that want to be +// hoisted along with it. If the same named attribute is present on multiple +// ancestors only the most narrowly scoped value will be used. +// static +void HoistableAttrInterface::gatherHoistableAttrs(Operation *fromOp, + NamedAttrList &dialectAttrs) { + for (auto attr : fromOp->getDialectAttrs()) { + if (auto hoistableAttr = llvm::dyn_cast( + attr.getValue())) { + if (hoistableAttr.shouldAttachToHoistedOps() && + !dialectAttrs.get(attr.getName())) { + dialectAttrs.push_back(attr); + } + } + } + if (auto *parentOp = fromOp->getParentOp()) + gatherHoistableAttrs(parentOp, dialectAttrs); +} + +// static +void HoistableAttrInterface::gatherHoistableAttrs(Operation *fromOp, + Operation *toOp) { + // Get the attributes specified on the target op first as those take + // precedence over any from ancestors. We also want to preserve any + // non-hoistable attrs when we reassign the dialect attrs. + NamedAttrList dialectAttrs; + for (auto attr : toOp->getDialectAttrs()) + dialectAttrs.push_back(attr); + + // Gather attributes from the op and its parents, only adding ones not already + // set on the op. + HoistableAttrInterface::gatherHoistableAttrs(fromOp, dialectAttrs); + + toOp->setDialectAttrs(dialectAttrs); +} + //===----------------------------------------------------------------------===// // IREE::Util::UtilDialect //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilInterfaces.td b/compiler/src/iree/compiler/Dialect/Util/IR/UtilInterfaces.td index c2155e45571d..eaae0d478c5d 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilInterfaces.td +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilInterfaces.td @@ -1157,6 +1157,17 @@ def Util_HoistableAttrInterface : AttrInterface<"HoistableAttrInterface"> { }] >, ]; + + let extraClassDeclaration = [{ + // Walks |fromOp| and up to gather all dialect attributes that want to be + // hoisted along with it. If the same named attribute is present on multiple + // ancestors only the most narrowly scoped value will be used. + static void gatherHoistableAttrs(Operation *fromOp, + NamedAttrList &dialectAttrs); + + // Copies any hoistable attributes from the source op to the target op. + static void gatherHoistableAttrs(Operation *fromOp, Operation *toOp); + }]; } def Util_HoistableOpInterface : OpInterface<"HoistableOpInterface"> { diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD.bazel index 769e6bfde25c..ce557eca3769 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD.bazel @@ -27,7 +27,6 @@ iree_compiler_cc_library( "HoistIntoGlobals.cpp", "IPO.cpp", "ImportResources.cpp", - "OutlineConstants.cpp", "PassDetail.h", "Passes.cpp", "Patterns.cpp", diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt index 85dade3fc813..d07b20584309 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt @@ -30,7 +30,6 @@ iree_cc_library( "HoistIntoGlobals.cpp" "IPO.cpp" "ImportResources.cpp" - "OutlineConstants.cpp" "PassDetail.h" "Passes.cpp" "Patterns.cpp" diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/HoistIntoGlobals.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/HoistIntoGlobals.cpp index 2467a5d584ad..aa360cb0944e 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/HoistIntoGlobals.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/HoistIntoGlobals.cpp @@ -10,6 +10,7 @@ #include "iree/compiler/Dialect/Util/IR/UtilTypes.h" #include "iree/compiler/Dialect/Util/Transforms/PassDetail.h" #include "iree/compiler/Dialect/Util/Transforms/Passes.h" +#include "iree/compiler/Utils/StringUtils.h" #include "llvm/Support/Debug.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/IR/Builders.h" @@ -31,22 +32,15 @@ static llvm::cl::opt clPrintDotGraphToFile( // Maps an original value in the program to the symbol name of a global. using HoistedValueMap = llvm::DenseMap; -// Walks |fromOp| and up to gather all dialect attributes that want to be -// hoisted along with it. If the same named attribute is present on multiple -// ancestors only the most narrowly scoped value will be used. -static void gatherHoistableAttrs(Operation *fromOp, - NamedAttrList &dialectAttrs) { - for (auto attr : fromOp->getDialectAttrs()) { - if (auto hoistableAttr = - dyn_cast(attr.getValue())) { - if (hoistableAttr.shouldAttachToHoistedOps() && - !dialectAttrs.get(attr.getName())) { - dialectAttrs.push_back(attr); - } - } - } - if (auto *parentOp = fromOp->getParentOp()) - gatherHoistableAttrs(parentOp, dialectAttrs); +static std::string getHoistedName(Type type) { + std::string str; + llvm::raw_string_ostream os(str); + os << "__hoisted_"; + type.print(os); + str = sanitizeSymbolName(str); + if (str.substr(str.size() - 1) == "_") + str = str.substr(0, str.size() - 1); // strip trailing _ + return str; } // Hoist expressions into globals. It is not expected that such a greedy @@ -130,22 +124,27 @@ class HoistIntoGlobalsPass : public HoistIntoGlobalsBase { OpBuilder builder(&getContext()); for (auto [originalValue, globalOp] : hoistedMap) { builder.setInsertionPointAfterValue(originalValue); - Value load = globalOp.createLoadOp(globalOp->getLoc(), builder) - .getLoadedGlobalValue(); + auto loadOp = globalOp.createLoadOp(globalOp->getLoc(), builder); + if (!originalValue.getDefiningOp() + ->getParentOfType()) { + loadOp.setGlobalImmutable(true); + } + Value loadedValue = loadOp.getLoadedGlobalValue(); // Call user hook to cast back to the original type. if (auto hoistableType = dyn_cast( originalValue.getType())) { - load = hoistableType.decodeStorageType(builder, load.getLoc(), - originalValue.getType(), load); + loadedValue = hoistableType.decodeStorageType( + builder, loadedValue.getLoc(), originalValue.getType(), + loadedValue); } - if (load.getType() != originalValue.getType()) { + if (loadedValue.getType() != originalValue.getType()) { getOperation().emitError() << "Unresolved conflict between casted global of type " - << load.getType() << " and original type " + << loadedValue.getType() << " and original type " << originalValue.getType(); return signalPassFailure(); } - originalValue.replaceAllUsesWith(load); + originalValue.replaceAllUsesWith(loadedValue); } cleanupDeadOps(constExprs); } @@ -168,7 +167,8 @@ class HoistIntoGlobalsPass : public HoistIntoGlobalsBase { // Gather any dialect attributes we may need to preserve. auto *topLevelOp = getTopLevelOp(originalValue.getDefiningOp()); NamedAttrList dialectAttrs; - gatherHoistableAttrs(topLevelOp, dialectAttrs); + IREE::Util::HoistableAttrInterface::gatherHoistableAttrs(topLevelOp, + dialectAttrs); // No existing mapping - create a new global. OpBuilder moduleBuilder(topLevelOp); @@ -269,12 +269,12 @@ class HoistIntoGlobalsPass : public HoistIntoGlobalsBase { // functions for setting the preferred storage type. auto hoistableType = dyn_cast(globalType); - // Get the preferred global storage type. if (hoistableType) { + // Allow the storage type of the global to differ from the local type. globalType = hoistableType.getPreferredStorageType(); } auto globalOp = moduleBuilder.create( - loc, "hoisted", false, globalType); + loc, getHoistedName(globalType), false, globalType); moduleSymbols.insert(globalOp); SymbolTable::setSymbolVisibility(globalOp, SymbolTable::Visibility::Private); @@ -290,17 +290,16 @@ class HoistIntoGlobalsPass : public HoistIntoGlobalsBase { clonedResult.print(llvm::dbgs()); llvm::dbgs() << "\n"; }); - // Cast to the preferred global storage type. if (hoistableType) { + // Allow casting to the global type if it differs from the local type. clonedResult = hoistableType.encodeStorageType( initializerBuilder, clonedResult.getLoc(), globalType, clonedResult); } if (clonedResult.getType() != globalType) { - globalOp.emitError() - << "Unresolved conflict between global of type " << globalType - << " and stored type " << clonedResult.getType(); - return failure(); + return globalOp.emitError() + << "unresolved conflict between global of type " << globalType + << " and stored type " << clonedResult.getType(); } globalOp.createStoreOp(loc, clonedResult, initializerBuilder); } diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/OutlineConstants.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/OutlineConstants.cpp deleted file mode 100644 index 023031abee3c..000000000000 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/OutlineConstants.cpp +++ /dev/null @@ -1,124 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include - -#include "iree/compiler/Dialect/Util/IR/UtilDialect.h" -#include "iree/compiler/Dialect/Util/IR/UtilOps.h" -#include "iree/compiler/Dialect/Util/Transforms/PassDetail.h" -#include "iree/compiler/Dialect/Util/Transforms/Passes.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Diagnostics.h" -#include "mlir/Pass/Pass.h" - -namespace mlir::iree_compiler::IREE::Util { - -// Returns true if |value| is worth outlining (large, etc). -static bool isOutlinableValue(Attribute value) { - if (auto elementsAttr = llvm::dyn_cast(value)) { - // Don't outline splats - we want those fused. - return !elementsAttr.isSplat(); - } - return false; -} - -struct ConstantDef { - Operation *op; - Type type; - ElementsAttr value; -}; - -// Returns a list of all constant-like shaped data ops in the module. -static SmallVector findConstantsInModule(mlir::ModuleOp moduleOp) { - SmallVector results; - for (auto callableOp : moduleOp.getOps()) { - auto *region = callableOp.getCallableRegion(); - if (!region) - continue; - for (auto &block : *region) { - for (auto &op : block.getOperations()) { - if (auto constantOp = dyn_cast(op)) { - if (isOutlinableValue(constantOp.getValue())) { - results.push_back(ConstantDef{ - constantOp, - constantOp.getType(), - llvm::cast(constantOp.getValue()), - }); - } - } - } - } - } - return results; -} - -class OutlineConstantsPass : public OutlineConstantsBase { -public: - OutlineConstantsPass() = default; - - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - registry.insert(); - } - - void runOnOperation() override { - auto moduleOp = getOperation(); - if (moduleOp.getBody()->empty()) - return; - - SymbolTable moduleSymbols(moduleOp); - std::string baseName = "_constant"; - - // Create all top-level util.globals from constants in the module. - OpBuilder moduleBuilder(&moduleOp.getBody()->front()); - std::vector> replacements; - for (auto &def : findConstantsInModule(moduleOp)) { - // New immutable global takes the constant attribute in its specified - // encoding. - auto globalOp = moduleBuilder.create( - def.op->getLoc(), baseName, /*isMutable=*/false, def.type, def.value); - globalOp.setPrivate(); - moduleSymbols.insert(globalOp); // uniques name - replacements.emplace_back(def.op, globalOp); - - // Prevent the variable from being re-inlined if the canonicalizer runs. - // By the time we've outlined things here we are sure we want them - // outlined even if the user runs an arbitrary number of passes between - // now and when we may use that information (HAL constant pooling, etc). - globalOp.setInliningPolicyAttr( - moduleBuilder.getAttr()); - } - - // Replace all of the constants with lookups for the new variables. - for (auto pair : replacements) { - auto *originalOp = pair.first; - auto globalOp = pair.second; - OpBuilder builder(moduleOp.getContext()); - builder.setInsertionPoint(originalOp); - auto loadOp = globalOp.createLoadOp(originalOp->getLoc(), builder); - - Value replacement; - if (auto constantOp = dyn_cast(originalOp)) { - // Directly replace constant with global constant value. - replacement = loadOp.getLoadedGlobalValue(); - } else { - assert(false && "unhandled constant op type"); - } - - originalOp->getResult(0).replaceAllUsesWith(replacement); - originalOp->erase(); - } - } -}; - -std::unique_ptr> createOutlineConstantsPass() { - return std::make_unique(); -} - -} // namespace mlir::iree_compiler::IREE::Util diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h index 402ee2ee3a38..a2aa22632dd2 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h @@ -30,7 +30,6 @@ createFixedPointIteratorPass(OpPassManager pipeline); std::unique_ptr> createFoldGlobalsPass(); std::unique_ptr> createFuseGlobalsPass(); std::unique_ptr> createIPOPass(); -std::unique_ptr> createOutlineConstantsPass(); std::unique_ptr> createPropagateSubrangesPass(); std::unique_ptr> createSimplifyGlobalAccessesPass(); std::unique_ptr> diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td index 191da7d11317..f3072d174a4c 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td @@ -65,14 +65,6 @@ def IPO : Pass<"iree-util-ipo", "mlir::ModuleOp"> { }]; } -def OutlineConstants : - Pass<"iree-util-outline-constants", "mlir::ModuleOp"> { - let summary = "Outlines tensor constants into util.globals at the module level."; - let constructor = [{ - mlir::iree_compiler::IREE::Util::createOutlineConstantsPass() - }]; -} - def PropagateSubranges : Pass<"iree-util-propagate-subranges", "mlir::ModuleOp"> { let summary = "Propagates resource subranges across the program."; let constructor = [{ diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel index df0835a9e207..6c608ccbf971 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel @@ -28,7 +28,6 @@ iree_lit_test_suite( "hoist_into_globals_linalg.mlir", "import_resources.mlir", "ipo.mlir", - "outline_constants.mlir", "patterns.mlir", "promote_bf16_to_f32.mlir", "promote_f16_to_f32.mlir", diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt index 53d73ec9120d..2ed4d402aa83 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt @@ -26,7 +26,6 @@ iree_lit_test_suite( "hoist_into_globals_linalg.mlir" "import_resources.mlir" "ipo.mlir" - "outline_constants.mlir" "patterns.mlir" "promote_bf16_to_f32.mlir" "promote_f16_to_f32.mlir" diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/hoist_into_globals.mlir b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/hoist_into_globals.mlir index 28a7c205a463..37f06055d490 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/hoist_into_globals.mlir +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/hoist_into_globals.mlir @@ -16,7 +16,7 @@ module @hoist_simple_const_expr { %1 = arith.constant 1 : i32 // CHECK-NOT: arith.constant // CHECK-NOT: iree_unregistered.const_expr - // CHECK: %[[VAL:.*]] = util.global.load @[[HOISTED_SYM]] : i32 + // CHECK: %[[VAL:.*]] = util.global.load immutable @[[HOISTED_SYM]] : i32 // CHECK: util.return %[[VAL]] %2 = "iree_unregistered.const_expr"(%0, %1) : (i32, i32) -> i32 util.return %2 : i32 @@ -141,8 +141,8 @@ module @hoist_tree_const_expr { // CHECK: util.func public @main util.func public @main() -> (i32, i32, i32) { - // CHECK-DAG: %[[LOAD_HOISTED_0:.*]] = util.global.load @[[HOISTED_0]] : i32 - // CHECK-DAG: %[[LOAD_HOISTED_1:.*]] = util.global.load @[[HOISTED_1]] : i32 + // CHECK-DAG: %[[LOAD_HOISTED_0:.*]] = util.global.load immutable @[[HOISTED_0]] : i32 + // CHECK-DAG: %[[LOAD_HOISTED_1:.*]] = util.global.load immutable @[[HOISTED_1]] : i32 // CHECK-DAG: %[[RESULT:.*]] = "iree_unregistered.var_expr"(%[[LOAD_HOISTED_1]]) // CHECK: util.return %[[LOAD_HOISTED_0]], %[[LOAD_HOISTED_1]], %[[RESULT]] %0 = arith.constant 0 : i32 @@ -171,7 +171,7 @@ module @hoist_const_expr_with_ineligible_consumer { // CHECK: } // CHECK: util.func public @main util.func public @main() -> i32 { - // CHECK-DAG: %[[LOAD_HOISTED_0:.*]] = util.global.load @[[HOISTED_0]] : i32 + // CHECK-DAG: %[[LOAD_HOISTED_0:.*]] = util.global.load immutable @[[HOISTED_0]] : i32 // CHECK-DAG: %[[RESULT:.*]] = "iree_unregistered.var_expr"(%[[LOAD_HOISTED_0]]) // CHECK: util.return %[[RESULT]] %0 = arith.constant 0 : i32 @@ -201,8 +201,8 @@ module @hoist_non_leaf_const_expr { // CHECK: } // CHECK: util.func public @main util.func public @main() -> (i32) { - // CHECK: %[[LOAD_HOISTED:.*]] = util.global.load @[[HOISTED]] : i32 - // CHECK: %[[RESULT:.*]] = "iree_unregistered.non_leaf_const_expr"(%hoisted) + // CHECK: %[[LOAD_HOISTED:.*]] = util.global.load immutable @[[HOISTED]] : i32 + // CHECK: %[[RESULT:.*]] = "iree_unregistered.non_leaf_const_expr"(%[[LOAD_HOISTED]]) // CHECK: util.return %[[RESULT]] %0 = arith.constant 0 : i32 %1 = arith.constant 1 : i32 @@ -236,7 +236,7 @@ module @hoist_implicit_capture { %1 = arith.constant 1 : i32 // CHECK-NOT: arith.constant // CHECK-NOT: iree_unregistered.const_expr - // CHECK: %[[VAL:.*]] = util.global.load @[[HOISTED_SYM]] : i32 + // CHECK: %[[VAL:.*]] = util.global.load immutable @[[HOISTED_SYM]] : i32 // CHECK: util.return %[[VAL]] %2 = "iree_unregistered.const_expr"(%0) ({ ^bb0(%inner0 : i32): diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/hoist_into_globals_linalg.mlir b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/hoist_into_globals_linalg.mlir index cb41b5ebacd2..1760ff2b2717 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/hoist_into_globals_linalg.mlir +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/hoist_into_globals_linalg.mlir @@ -27,7 +27,7 @@ module @compute_hoisted { linalg.yield %42 : f32 } -> tensor<5x6xf32> - // CHECK: %[[RESULT:.*]] = util.global.load @[[HOISTED]] : tensor<5x6xf32> + // CHECK: %[[RESULT:.*]] = util.global.load immutable @[[HOISTED]] : tensor<5x6xf32> // CHECK: util.return %[[RESULT]] util.return %3 : tensor<5x6xf32> } diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/outline_constants.mlir b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/outline_constants.mlir deleted file mode 100644 index 76b27dabd19b..000000000000 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/outline_constants.mlir +++ /dev/null @@ -1,30 +0,0 @@ -// RUN: iree-opt --split-input-file --iree-util-outline-constants %s | FileCheck %s - -// CHECK-LABEL: @scalarConstant -util.func @scalarConstant() { - // CHECK: = arith.constant 0 : i32 - %cst = arith.constant 0 : i32 - util.return -} - -// ----- - -// CHECK-LABEL: @splatConstant -util.func @splatConstant() { - // CHECK: = arith.constant dense<1.200000e+00> : tensor<512x128xf32> - %cst = arith.constant dense<1.2> : tensor<512x128xf32> - util.return -} - -// ----- - -// CHECK: util.global private @_constant {inlining_policy = #util.inline.never} = dense<[0.0287729427, 0.0297581609]> : tensor<2xf32> -// CHECK-NEXT: util.global private @_constant_0 {inlining_policy = #util.inline.never} = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00, 7.000000e+00]> : tensor<8xf32> -// CHECK-LABEL: @denseConstants -util.func @denseConstants() { - // CHECK: = util.global.load @_constant : tensor<2xf32> - %cst_0 = arith.constant dense<[0.0287729427, 0.0297581609]> : tensor<2xf32> - // CHECK-NEXT: = util.global.load @_constant_0 : tensor<8xf32> - %cst_1 = arith.constant dense<[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]> : tensor<8xf32> - util.return -} diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp index 61d3b6cfe549..65ad8c9c2032 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp @@ -100,6 +100,8 @@ static void fixupGlobalMutability(Operation *moduleOp, // If there are stores mark the global as mutable. globalInfo->op.setGlobalMutable(!globalInfo->getStores().empty()); } + for (auto loadOp : globalInfo->getLoads()) + loadOp.setGlobalImmutable(!globalInfo->op.isGlobalMutable()); }); for (auto *deadOp : deadOps) deadOp->erase(); diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/test/global_initialization.mlir b/compiler/src/iree/compiler/Dialect/VM/Transforms/test/global_initialization.mlir index 757fef225a50..c7930a5c9a2d 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Transforms/test/global_initialization.mlir +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/test/global_initialization.mlir @@ -35,6 +35,43 @@ vm.module @init_i32 { // ----- +// CHECK-LABEL: @mutability_change +vm.module @mutability_change { + // CHECK: vm.global.i32 private @g0 + vm.global.i32 private mutable @g0 : i32 = 0 : i32 + // CHECK: vm.global.i32 private mutable @g1 : i32 + vm.global.i32 private mutable @g1 = 123 : i32 + // CHECK: vm.global.i32 private mutable @g2 : i32 + vm.global.i32 private @g2 : i32 + + vm.initializer { + %c456 = vm.const.i32 456 + vm.global.store.i32 %c456, @g2 : i32 + vm.return + } + + // CHECK: vm.func public @func + vm.func public @func() { + // CHECK: vm.global.load.i32 immutable @g0 + vm.global.load.i32 @g0 : i32 + // CHECK: vm.global.load.i32 @g1 + vm.global.load.i32 @g1 : i32 + // CHECK: vm.global.load.i32 @g2 + vm.global.load.i32 immutable @g2 : i32 + vm.return + } + + // CHECK: vm.func private @__init() { + // CHECK-NEXT: %c123 = vm.const.i32 123 + // CHECK-NEXT: vm.global.store.i32 %c123, @g1 + // CHECK-NEXT: %c456 = vm.const.i32 456 + // CHECK-NEXT: vm.global.store.i32 %c456, @g2 + // CHECK-NEXT: vm.return + // CHECK-NEXT: } +} + +// ----- + // CHECK-LABEL: @init_ref vm.module @init_ref { // CHECK: vm.global.ref private mutable @g0 : !vm.ref diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/hoist_into_globals.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/hoist_into_globals.mlir index 46705a954f9e..4082bbfa4721 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/test/hoist_into_globals.mlir +++ b/compiler/src/iree/compiler/GlobalOptimization/test/hoist_into_globals.mlir @@ -10,7 +10,7 @@ module @hoist_sub_byte_tensor_store { // CHECK: util.return // CHECK: util.func public @main() -> tensor<64xi4> - // CHECK: %[[GLOBAL_LD:.+]] = util.global.load @{{.*}} : tensor<32xi8> + // CHECK: %[[GLOBAL_LD:.+]] = util.global.load immutable @{{.*}} : tensor<32xi8> // CHECK: %[[ORIG_VAL:.+]] = flow.tensor.bitcast %[[GLOBAL_LD]] : tensor<32xi8> -> tensor<64xi4> // CHECK: util.return %[[ORIG_VAL]] util.func public @main() -> (tensor<64xi4>) { @@ -48,9 +48,9 @@ module @hoist_tree_const_expr_i4 { // CHECK: util.func public @main util.func public @main() -> (tensor<8xi4>, tensor<8xi4>, tensor<8xi4>) { - // CHECK-DAG: %[[LOAD_HOISTED_0:.*]] = util.global.load @[[HOISTED_0]] : tensor<4xi8> + // CHECK-DAG: %[[LOAD_HOISTED_0:.*]] = util.global.load immutable @[[HOISTED_0]] : tensor<4xi8> // CHECK-DAG: %[[BITCAST_0:.*]] = flow.tensor.bitcast %[[LOAD_HOISTED_0]] : tensor<4xi8> -> tensor<8xi4> - // CHECK-DAG: %[[LOAD_HOISTED_1:.*]] = util.global.load @[[HOISTED_1]] : tensor<4xi8> + // CHECK-DAG: %[[LOAD_HOISTED_1:.*]] = util.global.load immutable @[[HOISTED_1]] : tensor<4xi8> // CHECK-DAG: %[[BITCAST_1:.*]] = flow.tensor.bitcast %[[LOAD_HOISTED_1]] : tensor<4xi8> -> tensor<8xi4> // CHECK-DAG: %[[RESULT:.*]] = "iree_unregistered.var_expr"(%[[BITCAST_1]]) // CHECK: util.return %[[BITCAST_0]], %[[BITCAST_1]], %[[RESULT]] @@ -128,7 +128,7 @@ module @hoist_inline_parameters { // CHECK-NEXT: flow.tensor.constant #flow.parameter.named<"compile"::"constant_hoisted_0"> // CHECK-NEXT: "iree_unregistered.const_expr" util.func public @main() -> tensor { - // CHECK: util.global.load @[[HOISTED]] + // CHECK: util.global.load immutable @[[HOISTED]] %parameter = flow.tensor.constant #flow.parameter.named<"compile"::"constant_hoisted_0"> : tensor %0 = "iree_unregistered.const_expr"(%parameter) : (tensor) -> tensor util.return %0 : tensor @@ -142,7 +142,7 @@ module @hoist_inline_parameters { // CHECK-LABEL: @hoist_dialect_attrs module @hoist_dialect_attrs { - // CHECK: util.global private @[[HOISTED:[a-z0-9]+]] + // CHECK: util.global private @[[HOISTED:[a-z0-9_]+]] // CHECK-SAME: hal.affinity = #hal.affinity.queue<[0, 1]> // CHECK: util.initializer // CHECK-SAME: hal.affinity = #hal.affinity.queue<[0, 1]> diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/BUILD.bazel index 4fb4155c7570..6717abb3f90f 100644 --- a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/BUILD.bazel +++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/BUILD.bazel @@ -28,7 +28,7 @@ iree_compiler_cc_library( ], deps = [ ":PassesIncGen", - "//compiler/src/iree/compiler/Dialect/Stream/IR", + "//compiler/src/iree/compiler/Dialect/Flow/IR", "//compiler/src/iree/compiler/Dialect/Util/IR", "//runtime/src/iree/base", "//runtime/src/iree/hal", diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/CMakeLists.txt index 7c2730585fdb..41704dd0ceeb 100644 --- a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/CMakeLists.txt +++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/CMakeLists.txt @@ -33,7 +33,7 @@ iree_cc_library( MLIRTransformUtils MLIRTransforms iree::base - iree::compiler::Dialect::Stream::IR + iree::compiler::Dialect::Flow::IR iree::compiler::Dialect::Util::IR iree::hal iree::io::file_handle diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ExportParameters.cpp b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ExportParameters.cpp index 4e67fb6016d9..539a850ec8af 100644 --- a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ExportParameters.cpp +++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ExportParameters.cpp @@ -4,7 +4,7 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h" +#include "iree/compiler/Dialect/Flow/IR/FlowTypes.h" #include "iree/compiler/Dialect/Util/IR/UtilDialect.h" #include "iree/compiler/Dialect/Util/IR/UtilOps.h" #include "iree/compiler/Dialect/Util/IR/UtilTypes.h" @@ -174,7 +174,7 @@ struct ExportParametersPass } // Change the global to reference the parameter. - globalOp.setGlobalInitialValue(IREE::Stream::NamedParameterAttr::get( + globalOp.setGlobalInitialValue(IREE::Flow::NamedParameterAttr::get( context, globalOp.getGlobalType(), StringAttr::get(context, scope), StringAttr::get(context, name), DictionaryAttr())); } diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/GenerateSplatParameterArchive.cpp b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/GenerateSplatParameterArchive.cpp index 4a6dfc2bc3ba..a319589404b2 100644 --- a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/GenerateSplatParameterArchive.cpp +++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/GenerateSplatParameterArchive.cpp @@ -4,7 +4,8 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h" +#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" +#include "iree/compiler/Dialect/Flow/IR/FlowTypes.h" #include "iree/compiler/Dialect/Util/IR/UtilDialect.h" #include "iree/compiler/Dialect/Util/IR/UtilOps.h" #include "iree/compiler/Dialect/Util/IR/UtilTypes.h" @@ -24,6 +25,34 @@ namespace mlir::iree_compiler::IREE::IO::Parameters { namespace { +// Returns a set of all unique parameters and the locations using them. +static SmallVector> +findAllParameters(ModuleOp moduleOp) { + llvm::MapVector> + parameterAttrs; + moduleOp.walk([&](Operation *op) { + if (auto globalOp = dyn_cast(op)) { + if (auto parameterAttr = + dyn_cast_if_present( + globalOp.getGlobalInitialValue())) { + parameterAttrs[parameterAttr].push_back(globalOp.getLoc()); + } + } else if (auto constantOp = dyn_cast(op)) { + if (auto parameterAttr = + dyn_cast_if_present( + constantOp.getValue())) { + parameterAttrs[parameterAttr].push_back(constantOp.getLoc()); + } + } + }); + SmallVector> locAttrs; + for (auto &entry : parameterAttrs) { + locAttrs.push_back(std::make_pair( + FusedLoc::get(moduleOp.getContext(), entry.second), entry.first)); + } + return locAttrs; +} + static Attribute getDefaultSplatAttr(Type elementType) { // Today we only support basic types where 0 bits represent zeros - that lets // us just splat out the right number of bits. @@ -50,20 +79,16 @@ struct GenerateSplatParameterArchivePass if (failed(builder)) return signalPassFailure(); - // Walk the globals in the module. - for (auto globalOp : moduleOp.getOps()) { + // Find all parameters in the module and add them to the builder. + // NOTE: there may be no parameters but we still will create the archive + // so that subsequent tooling that tries to load it succeeds. + auto parameterAttrs = findAllParameters(moduleOp); + for (auto [loc, parameterAttr] : parameterAttrs) { // Only support types we can meaningfully generate splats for. - auto shapedType = dyn_cast(globalOp.getGlobalType()); + auto shapedType = dyn_cast(parameterAttr.getType()); if (!shapedType) continue; - // Look for globals backed by parameters. - auto parameterAttr = - dyn_cast_if_present( - globalOp.getGlobalInitialValue()); - if (!parameterAttr) - continue; - // TODO: support other patterns/generators. auto elementAttr = getDefaultSplatAttr(shapedType.getElementType()); @@ -71,7 +96,7 @@ struct GenerateSplatParameterArchivePass SmallVector pattern; llvm::raw_svector_ostream os(pattern); if (failed(IREE::Util::SerializableAttrInterface::serializeSplatValue( - globalOp.getLoc(), elementAttr, + loc, elementAttr, /*count=*/1, llvm::endianness::little, os))) { return signalPassFailure(); } @@ -94,10 +119,6 @@ struct GenerateSplatParameterArchivePass } } - // Early exit if no parameter backed globals present. - if (iree_io_parameter_archive_builder_is_empty(builder->get())) - return; - // Create the parameter archive file. auto fileStreamIndexOr = createParameterIndex(moduleOp, std::move(builder.value()), filePath); diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ImportParameters.cpp b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ImportParameters.cpp index 8a3888e6a5eb..288495eac3d6 100644 --- a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ImportParameters.cpp +++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ImportParameters.cpp @@ -4,7 +4,7 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h" +#include "iree/compiler/Dialect/Flow/IR/FlowTypes.h" #include "iree/compiler/Dialect/Util/IR/UtilDialect.h" #include "iree/compiler/Dialect/Util/IR/UtilOps.h" #include "iree/compiler/Dialect/Util/IR/UtilTypes.h" @@ -255,7 +255,7 @@ importParameterFromFile(StringRef fullName, ShapedType globalType, // Import the given |parameterAttr| from |entry|. static FailureOr importParameter(StringRef fullName, ShapedType globalType, - IREE::Stream::NamedParameterAttr parameterAttr, + IREE::Flow::NamedParameterAttr parameterAttr, const iree_io_parameter_index_entry_t *entry) { switch (entry->type) { case IREE_IO_PARAMETER_INDEX_ENTRY_STORAGE_TYPE_SPLAT: @@ -292,7 +292,7 @@ struct ImportParametersPass for (auto &key : keys) importKeys.insert(key); auto shouldImportParameter = - [&](IREE::Stream::NamedParameterAttr parameterAttr) -> bool { + [&](IREE::Flow::NamedParameterAttr parameterAttr) -> bool { // Always try to import explicitly named parameters. if (importKeys.contains(parameterAttr.getKey().getValue())) return true; // key match @@ -308,9 +308,8 @@ struct ImportParametersPass // Find all parameters and try to import them. for (auto globalOp : moduleOp.getOps()) { // Only inspect parameter globals. - auto parameterAttr = - dyn_cast_if_present( - globalOp.getGlobalInitialValue()); + auto parameterAttr = dyn_cast_if_present( + globalOp.getGlobalInitialValue()); if (!parameterAttr) continue; diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/Passes.td b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/Passes.td index 00d7dad17bed..603d23f9202e 100644 --- a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/Passes.td +++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/Passes.td @@ -14,7 +14,7 @@ def ExportParametersPass : let summary = "Exports all global constants to an archive file when " "they are larger than the specified minimum size."; let dependentDialects = [ - "IREE::Stream::StreamDialect", + "IREE::Flow::FlowDialect", "IREE::Util::UtilDialect", ]; let options = [ @@ -42,7 +42,7 @@ def ImportParametersPass : Pass<"iree-io-import-parameters", "mlir::ModuleOp"> { let summary = "Imports parameters from an archive file."; let dependentDialects = [ - "IREE::Stream::StreamDialect", + "IREE::Flow::FlowDialect", "IREE::Util::UtilDialect", ]; let options = [ diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/export_parameters.mlir b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/export_parameters.mlir index 81ba9da616b0..7f54149f0ea7 100644 --- a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/export_parameters.mlir +++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/export_parameters.mlir @@ -1,42 +1,42 @@ // RUN: iree-opt --pass-pipeline="builtin.module(iree-io-export-parameters{path="opt=%t.irpa" minimum-size=0})" %s | FileCheck %s // RUN: iree-dump-parameters --parameters=%t.irpa | FileCheck %s --check-prefix=DUMP -// CHECK: util.global private @constant_scalar_i1 = #stream.parameter.named<"opt"::"constant_scalar_i1"> : tensor +// CHECK: util.global private @constant_scalar_i1 = #flow.parameter.named<"opt"::"constant_scalar_i1"> : tensor // DUMP: - | - | 1 | `constant_scalar_i1` util.global private @constant_scalar_i1 = dense : tensor -// CHECK-NEXT: util.global private @constant_dense_2xi1 = #stream.parameter.named<"opt"::"constant_dense_2xi1"> : tensor<2xi1> +// CHECK-NEXT: util.global private @constant_dense_2xi1 = #flow.parameter.named<"opt"::"constant_dense_2xi1"> : tensor<2xi1> // DUMP-NEXT: {{[0-9]+}} | {{[0-9]+}} | 2 | `constant_dense_2xi1` util.global private @constant_dense_2xi1 = dense<[true, false]> : tensor<2xi1> -// CHECK-NEXT: util.global private @constant_dense_3xi4 = #stream.parameter.named<"opt"::"constant_dense_3xi4"> : tensor<3xi4> +// CHECK-NEXT: util.global private @constant_dense_3xi4 = #flow.parameter.named<"opt"::"constant_dense_3xi4"> : tensor<3xi4> // DUMP-NEXT: {{[0-9]+}} | {{[0-9]+}} | 2 | `constant_dense_3xi4` util.global private @constant_dense_3xi4 = dense<[4, 5, 6]> : tensor<3xi4> -// CHECK-NEXT: util.global private @constant_dense_2xi8 = #stream.parameter.named<"opt"::"constant_dense_2xi8"> : tensor<2xi8> +// CHECK-NEXT: util.global private @constant_dense_2xi8 = #flow.parameter.named<"opt"::"constant_dense_2xi8"> : tensor<2xi8> // DUMP-NEXT: {{[0-9]+}} | {{[0-9]+}} | 2 | `constant_dense_2xi8` util.global private @constant_dense_2xi8 = dense<[4, 5]> : tensor<2xi8> -// CHECK-NEXT: util.global private @constant_dense_2xf32 = #stream.parameter.named<"opt"::"constant_dense_2xf32"> : tensor<2xf32> +// CHECK-NEXT: util.global private @constant_dense_2xf32 = #flow.parameter.named<"opt"::"constant_dense_2xf32"> : tensor<2xf32> // DUMP-NEXT: {{[0-9]+}} | {{[0-9]+}} | 8 | `constant_dense_2xf32` util.global private @constant_dense_2xf32 = dense<[11.0, 12.0]> : tensor<2xf32> -// CHECK-NEXT: util.global private @constant_splat_2xf32 = #stream.parameter.named<"opt"::"constant_splat_2xf32"> : tensor<2xf32> +// CHECK-NEXT: util.global private @constant_splat_2xf32 = #flow.parameter.named<"opt"::"constant_splat_2xf32"> : tensor<2xf32> // DUMP-NEXT: - | - | 8 | `constant_splat_2xf32` util.global private @constant_splat_2xf32 = dense<11.0> : tensor<2xf32> -// CHECK-NEXT: util.global private mutable @mutable_scalar_i1 = #stream.parameter.named<"opt"::"mutable_scalar_i1"> : tensor +// CHECK-NEXT: util.global private mutable @mutable_scalar_i1 = #flow.parameter.named<"opt"::"mutable_scalar_i1"> : tensor // DUMP-NEXT: {{[0-9]+}} | {{[0-9]+}} | 1 | `mutable_scalar_i1` util.global private mutable @mutable_scalar_i1 = dense : tensor -// CHECK-NEXT: util.global private mutable @mutable_dense_3xi4 = #stream.parameter.named<"opt"::"mutable_dense_3xi4"> : tensor<3xi4> +// CHECK-NEXT: util.global private mutable @mutable_dense_3xi4 = #flow.parameter.named<"opt"::"mutable_dense_3xi4"> : tensor<3xi4> // DUMP-NEXT: {{[0-9]+}} | {{[0-9]+}} | 2 | `mutable_dense_3xi4` util.global private mutable @mutable_dense_3xi4 = dense<[4, 5, 6]> : tensor<3xi4> -// CHECK-NEXT: util.global private mutable @mutable_dense_2xf32 = #stream.parameter.named<"opt"::"mutable_dense_2xf32"> : tensor<2xf32> +// CHECK-NEXT: util.global private mutable @mutable_dense_2xf32 = #flow.parameter.named<"opt"::"mutable_dense_2xf32"> : tensor<2xf32> // DUMP-NEXT: {{[0-9]+}} | {{[0-9]+}} | 8 | `mutable_dense_2xf32` util.global private mutable @mutable_dense_2xf32 = dense<[11.0, 12.0]> : tensor<2xf32> -// CHECK-NEXT: util.global private mutable @mutable_splat_2xf32 = #stream.parameter.named<"opt"::"mutable_splat_2xf32"> : tensor<2xf32> +// CHECK-NEXT: util.global private mutable @mutable_splat_2xf32 = #flow.parameter.named<"opt"::"mutable_splat_2xf32"> : tensor<2xf32> // DUMP-NEXT: {{[0-9]+}} | {{[0-9]+}} | 8 | `mutable_splat_2xf32` util.global private mutable @mutable_splat_2xf32 = dense<11.0> : tensor<2xf32> diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/generate_splat_parameter_archive.mlir b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/generate_splat_parameter_archive.mlir index 4944c01cc53d..215d1cca6dbc 100644 --- a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/generate_splat_parameter_archive.mlir +++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/generate_splat_parameter_archive.mlir @@ -3,20 +3,27 @@ // CHECK: util.global private @tensor_i1 // DUMP: - | - | 1 | `tensor_i1` -util.global private @tensor_i1 = #stream.parameter.named<"opt"::"tensor_i1"> : tensor +util.global private @tensor_i1 = #flow.parameter.named<"opt"::"tensor_i1"> : tensor // CHECK-NEXT: util.global private @tensor_i8 // DUMP-NEXT: - | - | 1 | `tensor_i8` -util.global private @tensor_i8 = #stream.parameter.named<"opt"::"tensor_i8"> : tensor +util.global private @tensor_i8 = #flow.parameter.named<"opt"::"tensor_i8"> : tensor // CHECK-NEXT: util.global private @tensor_1x2xi32 // DUMP-NEXT: - | - | 8 | `tensor_1x2xi32` -util.global private @tensor_1x2xi32 = #stream.parameter.named<"opt"::"tensor_1x2xi32"> : tensor<1x2xi32> +util.global private @tensor_1x2xi32 = #flow.parameter.named<"opt"::"tensor_1x2xi32"> : tensor<1x2xi32> // CHECK-NEXT: util.global private @tensor_2x2xi4 // DUMP-NEXT: - | - | 2 | `tensor_2x2xi4` -util.global private @tensor_2x2xi4 = #stream.parameter.named<"opt"::"tensor_2x2xi4"> : tensor<2x2xi4> +util.global private @tensor_2x2xi4 = #flow.parameter.named<"opt"::"tensor_2x2xi4"> : tensor<2x2xi4> // CHECK-NEXT: util.global private @tensor_3xi4 // DUMP-NEXT: - | - | 2 | `tensor_3xi4` -util.global private @tensor_3xi4 = #stream.parameter.named<"opt"::"tensor_3xi4"> : tensor<3xi4> +util.global private @tensor_3xi4 = #flow.parameter.named<"opt"::"tensor_3xi4"> : tensor<3xi4> + +util.func private @function() { + // CHECK: flow.tensor.constant + // DUMP-NEXT: - | - | 4 | `inline` + flow.tensor.constant #flow.parameter.named<"opt"::"inline"> : tensor<4xi8> + util.return +} diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/import_parameters.mlir b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/import_parameters.mlir index 075b9d0c95fa..45eb7b77924d 100644 --- a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/import_parameters.mlir +++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/import_parameters.mlir @@ -1,15 +1,15 @@ // RUN: iree-opt --pass-pipeline="builtin.module(iree-io-export-parameters{path="opt=%t.irpa" minimum-size=0},iree-io-import-parameters{paths="opt=%t.irpa"})" %s | FileCheck %s // NOTE: packed types not supported for import yet. -// CHECK: util.global private @constant_scalar_i1 = #stream.parameter.named +// CHECK: util.global private @constant_scalar_i1 = #flow.parameter.named util.global private @constant_scalar_i1 = dense : tensor // NOTE: packed types not supported for import yet. -// CHECK: util.global private @constant_dense_2xi1 = #stream.parameter.named +// CHECK: util.global private @constant_dense_2xi1 = #flow.parameter.named util.global private @constant_dense_2xi1 = dense<[true, false]> : tensor<2xi1> // NOTE: packed types not supported for import yet. -// CHECK: util.global private @constant_dense_3xi4 = #stream.parameter.named +// CHECK: util.global private @constant_dense_3xi4 = #flow.parameter.named util.global private @constant_dense_3xi4 = dense<[4, 5, 6]> : tensor<3xi4> // CHECK: util.global private @constant_dense_2xi8 = dense<[4, 5]> : tensor<2xi8> @@ -22,11 +22,11 @@ util.global private @constant_dense_2xf32 = dense<[1.100000e+01, 1.200000e+01]> util.global private @constant_splat_2xf32 = dense<1.100000e+01> : tensor<2xf32> // NOTE: packed types not supported for import yet. -// CHECK: util.global private mutable @mutable_scalar_i1 = #stream.parameter.named +// CHECK: util.global private mutable @mutable_scalar_i1 = #flow.parameter.named util.global private mutable @mutable_scalar_i1 = dense : tensor // NOTE: packed types not supported for import yet. -// CHECK: util.global private mutable @mutable_dense_3xi4 = #stream.parameter.named +// CHECK: util.global private mutable @mutable_dense_3xi4 = #flow.parameter.named util.global private mutable @mutable_dense_3xi4 = dense<[4, 5, 6]> : tensor<3xi4> // CHECK: util.global private mutable @mutable_dense_2xf32 = dense<[1.100000e+01, 1.200000e+01]> : tensor<2xf32> diff --git a/runtime/bindings/python/tests/io_runtime_test.py b/runtime/bindings/python/tests/io_runtime_test.py index 417b99de6f9d..3d07e307bfcc 100644 --- a/runtime/bindings/python/tests/io_runtime_test.py +++ b/runtime/bindings/python/tests/io_runtime_test.py @@ -17,10 +17,10 @@ TEST_COMPILED = None TEST_ASM = r""" -util.global private @a0 = #stream.parameter.named<"a"::"a0"> : tensor<4xi64> -util.global private @a1 = #stream.parameter.named<"a"::"a1"> : tensor<4xi64> -util.global private @b0 = #stream.parameter.named<"b"::"b0"> : tensor<8xi64> -util.global private @b1 = #stream.parameter.named<"b"::"b1"> : tensor<8xi64> +util.global private @a0 = #flow.parameter.named<"a"::"a0"> : tensor<4xi64> +util.global private @a1 = #flow.parameter.named<"a"::"a1"> : tensor<4xi64> +util.global private @b0 = #flow.parameter.named<"b"::"b0"> : tensor<8xi64> +util.global private @b1 = #flow.parameter.named<"b"::"b1"> : tensor<8xi64> func.func @echo() -> (tensor<4xi64>, tensor<4xi64>, tensor<8xi64>, tensor<8xi64>) { %a0 = util.global.load @a0 : tensor<4xi64> %a1 = util.global.load @a1 : tensor<4xi64> diff --git a/tests/e2e/parameters/generate_splat_archive.mlir b/tests/e2e/parameters/generate_splat_archive.mlir index a7b360701212..74a6725e7163 100644 --- a/tests/e2e/parameters/generate_splat_archive.mlir +++ b/tests/e2e/parameters/generate_splat_archive.mlir @@ -13,10 +13,10 @@ // CHECK-LABEL: EXEC @main // CHECK: 1x2xi32=[0 0] -util.global private @array_global_0 = #stream.parameter.named<"scope"::"global_0"> : tensor<1x2xi32> -util.global private @dense_global_1 = #stream.parameter.named<"scope"::"global_1"> : tensor<2x2xi32> -util.global private @dense_global_2 = #stream.parameter.named<"scope"::"global_2"> : tensor<1x2xi32> -util.global private @dense_global_3 = #stream.parameter.named<"scope"::"global_3"> : tensor<2x2xi32> +util.global private @array_global_0 = #flow.parameter.named<"scope"::"global_0"> : tensor<1x2xi32> +util.global private @dense_global_1 = #flow.parameter.named<"scope"::"global_1"> : tensor<2x2xi32> +util.global private @dense_global_2 = #flow.parameter.named<"scope"::"global_2"> : tensor<1x2xi32> +util.global private @dense_global_3 = #flow.parameter.named<"scope"::"global_3"> : tensor<2x2xi32> func.func @main(%arg0: tensor<1x2xi32>) -> tensor<1x2xi32> { %cst = arith.constant 0 : i32 %3 = util.global.load @array_global_0 : tensor<1x2xi32> diff --git a/tools/test/parameters_scoped.mlir b/tools/test/parameters_scoped.mlir index 2bc7c9eccef2..9b294ab173b3 100644 --- a/tools/test/parameters_scoped.mlir +++ b/tools/test/parameters_scoped.mlir @@ -14,10 +14,10 @@ // provide content for a single scope but not to have a single file provide // content for multiple scopes. Since parameter keys only need to be unique // within a scope this test could use the same name for both scopes if needed. -util.global private @a0 = #stream.parameter.named<"a"::"a0"> : tensor<4xi64> -util.global private @a1 = #stream.parameter.named<"a"::"a1"> : tensor<4xi64> -util.global private @b0 = #stream.parameter.named<"b"::"b0"> : tensor<8xi64> -util.global private @b1 = #stream.parameter.named<"b"::"b1"> : tensor<8xi64> +util.global private @a0 = #flow.parameter.named<"a"::"a0"> : tensor<4xi64> +util.global private @a1 = #flow.parameter.named<"a"::"a1"> : tensor<4xi64> +util.global private @b0 = #flow.parameter.named<"b"::"b0"> : tensor<8xi64> +util.global private @b1 = #flow.parameter.named<"b"::"b1"> : tensor<8xi64> func.func @echo() -> (tensor<4xi64>, tensor<4xi64>, tensor<8xi64>, tensor<8xi64>) { %a0 = util.global.load @a0 : tensor<4xi64> %a1 = util.global.load @a1 : tensor<4xi64> diff --git a/tools/test/parameters_unscoped.mlir b/tools/test/parameters_unscoped.mlir index 933cb7772aa7..715f6862c5b4 100644 --- a/tools/test/parameters_unscoped.mlir +++ b/tools/test/parameters_unscoped.mlir @@ -12,10 +12,10 @@ // Simple named parameters with no scope. Parameter files are combined at // runtime to allow for filesystem sharding while still providing a flat set of // parameters in the compiler input. -util.global private @a0 = #stream.parameter.named<"a0"> : tensor<4xi64> -util.global private @a1 = #stream.parameter.named<"a1"> : tensor<4xi64> -util.global private @b0 = #stream.parameter.named<"b0"> : tensor<8xi64> -util.global private @b1 = #stream.parameter.named<"b1"> : tensor<8xi64> +util.global private @a0 = #flow.parameter.named<"a0"> : tensor<4xi64> +util.global private @a1 = #flow.parameter.named<"a1"> : tensor<4xi64> +util.global private @b0 = #flow.parameter.named<"b0"> : tensor<8xi64> +util.global private @b1 = #flow.parameter.named<"b1"> : tensor<8xi64> func.func @echo() -> (tensor<4xi64>, tensor<4xi64>, tensor<8xi64>, tensor<8xi64>) { %a0 = util.global.load @a0 : tensor<4xi64> %a1 = util.global.load @a1 : tensor<4xi64>