Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[HAL] Add a pass to materialize homogeneous encodings. #14724

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -131,20 +131,33 @@ static MatmulTileParams chooseMatmulTileParams(EncodingUser user,

struct CPUMaterializeEncodingPass
: public CPUMaterializeEncodingBase<CPUMaterializeEncodingPass> {
CPUMaterializeEncodingPass() : targetAttr(nullptr) {}
explicit CPUMaterializeEncodingPass(IREE::HAL::ExecutableTargetAttr attr)
: targetAttr(attr) {}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<arith::ArithDialect, IREE::LinalgExt::IREELinalgExtDialect,
IREE::Codegen::IREECodegenDialect>();
}
void runOnOperation() override;

private:
IREE::HAL::ExecutableTargetAttr targetAttr;
};

struct CPUMaterializeUpperBoundTileSizePass
: public CPUMaterializeUpperBoundTileSizeBase<
CPUMaterializeUpperBoundTileSizePass> {
CPUMaterializeUpperBoundTileSizePass() = default;
explicit CPUMaterializeUpperBoundTileSizePass(
ArrayRef<IREE::HAL::ExecutableTargetAttr> attrs)
: targetAttrs(attrs) {}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<arith::ArithDialect>();
}
void runOnOperation() override;

private:
SmallVector<IREE::HAL::ExecutableTargetAttr, 4> targetAttrs;
};

FailureOr<MaterializeEncodingInfo>
Expand Down Expand Up @@ -235,7 +248,8 @@ void CPUMaterializeEncodingPass::runOnOperation() {
MLIRContext *context = &getContext();
auto operation = getOperation();
RewritePatternSet materializeEncodingPattern(context);
auto targetAttr = ExecutableTargetAttr::lookup(operation);
if (!targetAttr)
targetAttr = ExecutableTargetAttr::lookup(operation);
auto materializeEncodingFn = getMaterializeEncodingFn(targetAttr);
if (!materializeEncodingFn) {
return signalPassFailure();
Expand Down Expand Up @@ -269,8 +283,10 @@ void CPUMaterializeEncodingPass::runOnOperation() {
void CPUMaterializeUpperBoundTileSizePass::runOnOperation() {
MLIRContext *context = &getContext();
auto operation = getOperation();
auto targetAttrs =
IREE::HAL::DeviceTargetAttr::lookupExecutableTargets(operation);
if (targetAttrs.empty()) {
targetAttrs =
IREE::HAL::DeviceTargetAttr::lookupExecutableTargets(operation);
}
RewritePatternSet patterns(context);
MaterializeEncodingFn materializeEncodingFn =
getUpperBoundMaterializeEncodingFn(targetAttrs);
Expand All @@ -290,11 +306,20 @@ std::unique_ptr<OperationPass<func::FuncOp>>
createCPUMaterializeEncodingPass() {
return std::make_unique<CPUMaterializeEncodingPass>();
}
std::unique_ptr<OperationPass<func::FuncOp>>
createCPUMaterializeEncodingPass(IREE::HAL::ExecutableTargetAttr targetAttr) {
return std::make_unique<CPUMaterializeEncodingPass>(targetAttr);
}

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createCPUMaterializeUpperBoundTileSizePass() {
return std::make_unique<CPUMaterializeUpperBoundTileSizePass>();
}
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createCPUMaterializeUpperBoundTileSizePass(
ArrayRef<IREE::HAL::ExecutableTargetAttr> targetAttrs) {
return std::make_unique<CPUMaterializeUpperBoundTileSizePass>(targetAttrs);
}

} // namespace iree_compiler
} // namespace mlir
6 changes: 6 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/CPU/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#ifndef IREE_COMPILER_CODEGEN_COMMON_CPU_PASSES_H_
#define IREE_COMPILER_CODEGEN_COMMON_CPU_PASSES_H_

#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"

Expand All @@ -23,6 +24,8 @@ namespace iree_compiler {
/// linalg_ext.unset_encoding -> tensor.unpack
/// linalg.matmul -> linalg.mmt4d
std::unique_ptr<OperationPass<func::FuncOp>> createCPUMaterializeEncodingPass();
std::unique_ptr<OperationPass<func::FuncOp>>
createCPUMaterializeEncodingPass(IREE::HAL::ExecutableTargetAttr targetAttr);

/// Like createLLVMCPUMaterializeEncodingPass, but specifically for
/// linalg_ext.upper_bound_tile_size, converting it to constants.
Expand All @@ -41,6 +44,9 @@ std::unique_ptr<OperationPass<func::FuncOp>> createCPUMaterializeEncodingPass();
// as needed.
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createCPUMaterializeUpperBoundTileSizePass();
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createCPUMaterializeUpperBoundTileSizePass(
ArrayRef<IREE::HAL::ExecutableTargetAttr> targetAttrs);

void registerCodegenCommonCPUPasses();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ iree_compiler_cc_library(
":LinkerTool",
":StaticLibraryGenerator",
"//compiler/src/iree/compiler/Codegen/Common",
"//compiler/src/iree/compiler/Codegen/Common/CPU:CommonCPUPasses",
"//compiler/src/iree/compiler/Codegen/Dialect:IREECodegenDialect",
"//compiler/src/iree/compiler/Codegen/LLVMCPU",
"//compiler/src/iree/compiler/Codegen/Utils",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ iree_cc_library(
MLIRTargetLLVMIRExport
MLIRTransformDialect
iree::compiler::Codegen::Common
iree::compiler::Codegen::Common::CPU::CommonCPUPasses
iree::compiler::Codegen::Dialect::IREECodegenDialect
iree::compiler::Codegen::LLVMCPU
iree::compiler::Codegen::Utils
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h"
#include "iree/compiler/Codegen/Common/CPU/Passes.h"
#include "iree/compiler/Codegen/Dialect/IREECodegenDialect.h"
#include "iree/compiler/Codegen/LLVMCPU/Passes.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
Expand Down Expand Up @@ -218,6 +219,15 @@ class LLVMCPUTargetBackend final : public TargetBackend {
buildLLVMCPUCodegenPassPipeline(passManager);
}

void buildMaterializeEncodingsPassPipeline(
IREE::HAL::ExecutableTargetAttr executableTarget,
OpPassManager &passManager) override {
passManager.addNestedPass<func::FuncOp>(
createCPUMaterializeUpperBoundTileSizePass(executableTarget));
passManager.addNestedPass<func::FuncOp>(
createCPUMaterializeEncodingPass({executableTarget}));
}

void buildLinkingPassPipeline(OpPassManager &passManager) override {
buildLLVMCPULinkingPassPipeline(passManager);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,10 @@ class TargetBackend {
buildTranslationPassPipeline(IREE::HAL::ExecutableVariantOp variantOp,
OpPassManager &passManager) = 0;

virtual void buildMaterializeEncodingsPassPipeline(
IREE::HAL::ExecutableTargetAttr executableTarget,
OpPassManager &passManager) {}

// Inserts passes used to link `hal.executable.variant` ops together.
// The pass manager will be nested on the parent module of `hal.executable`
// ops and the pipeline will need to find relevant variant ops itself.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ iree_compiler_cc_library(
"InlineDeviceSwitches.cpp",
"LinkExecutables.cpp",
"MaterializeDispatchInstrumentation.cpp",
"MaterializeHomogeneousEncodings.cpp",
"MaterializeInterfaces.cpp",
"MaterializeResourceCaches.cpp",
"MemoizeDeviceQueries.cpp",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ iree_cc_library(
"InlineDeviceSwitches.cpp"
"LinkExecutables.cpp"
"MaterializeDispatchInstrumentation.cpp"
"MaterializeHomogeneousEncodings.cpp"
"MaterializeInterfaces.cpp"
"MaterializeResourceCaches.cpp"
"MemoizeDeviceQueries.cpp"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/HAL/Target/TargetBackend.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/Pass/Pass.h"

namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace HAL {

class MaterializeHomogeneousEncodingsPass
: public PassWrapper<MaterializeHomogeneousEncodingsPass,
OperationPass<ModuleOp>> {
public:
MaterializeHomogeneousEncodingsPass()
: targetRegistry(TargetBackendRegistry::getGlobal()) {}

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<IREE::HAL::HALDialect>();
}

StringRef getArgument() const override {
return "iree-hal-materialize-homogeneous-encodings";
}

StringRef getDescription() const override {
return "Mateiralizes logical encodings to physical encodings if there is "
"a single device target.";
}

void runOnOperation() override {
auto moduleOp = getOperation();
auto targetsAttr = moduleOp->getAttrOfType<ArrayAttr>("hal.device.targets");
if (!targetsAttr || targetsAttr.size() != 1) {
return;
}
auto deviceTarget = cast<IREE::HAL::DeviceTargetAttr>(targetsAttr[0]);
SmallVector<ExecutableTargetAttr, 4> executableTargets =
deviceTarget.getExecutableTargets();
if (executableTargets.size() != 1) {
return;
}
auto executableTarget = executableTargets[0];
OpPassManager passManager(moduleOp.getOperationName());
auto targetBackend =
targetRegistry.getTargetBackend(executableTarget.getBackend());
if (!targetBackend) {
moduleOp.emitError() << "unregistered target backend '"
<< executableTarget.getBackend() << "'";
return;
}
targetBackend->buildMaterializeEncodingsPassPipeline(executableTarget,
passManager);
if (failed(runPipeline(passManager, moduleOp))) {
return signalPassFailure();
}
}

private:
const TargetBackendRegistry &targetRegistry;
};

std::unique_ptr<OperationPass<ModuleOp>>
createMaterializeHomogeneousEncodingsPass() {
return std::make_unique<MaterializeHomogeneousEncodingsPass>();
}

static PassRegistration<MaterializeHomogeneousEncodingsPass> pass([] {
return std::make_unique<MaterializeHomogeneousEncodingsPass>();
});

} // namespace HAL
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,9 @@ void buildHALTransformPassPipeline(OpPassManager &passManager,
// Executable translation
//----------------------------------------------------------------------------

FunctionLikeNest(passManager)
.addPass(createCPUMaterializeUpperBoundTileSizePass);
FunctionLikeNest(passManager).addPass([]() {
return createCPUMaterializeUpperBoundTileSizePass();
});

// Preprocess executables using an external tool. The tool may mutate one or
// more variants and even insert or remove variants.
Expand Down
4 changes: 4 additions & 0 deletions compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,9 @@ std::unique_ptr<OperationPass<mlir::ModuleOp>> createMemoizeDeviceQueriesPass();
std::unique_ptr<OperationPass<mlir::ModuleOp>>
createMaterializeInterfacesPass();

std::unique_ptr<OperationPass<mlir::ModuleOp>>
createMaterializeHomogeneousEncodingsPass();

// Dumps individual hal.executable source listings to |path|.
std::unique_ptr<OperationPass<mlir::ModuleOp>>
createDumpExecutableSourcesPass(StringRef path);
Expand Down Expand Up @@ -213,6 +216,7 @@ inline void registerHALPasses() {
createLinkExecutablesPass(TargetBackendRegistry::getGlobal());
createLinkTargetExecutablesPass(TargetBackendRegistry::getGlobal(), "");
createMaterializeDispatchInstrumentationPass(0);
createMaterializeHomogeneousEncodingsPass();
createMaterializeInterfacesPass();
createMaterializeResourceCachesPass(targetOptions);
createMemoizeDeviceQueriesPass();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ iree_lit_test_suite(
"fixup_legacy_sync.mlir",
"inline_device_switches.mlir",
"materialize_dispatch_instrumentation.mlir",
"materialize_homogeneous_encodings.mlir",
"materialize_interfaces.mlir",
"materialize_resource_caches.mlir",
"memoize_device_queries.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ iree_lit_test_suite(
"fixup_legacy_sync.mlir"
"inline_device_switches.mlir"
"materialize_dispatch_instrumentation.mlir"
"materialize_homogeneous_encodings.mlir"
"materialize_interfaces.mlir"
"materialize_resource_caches.mlir"
"memoize_device_queries.mlir"
Expand Down
Loading
Loading