diff --git a/compiler/src/iree/compiler/Reducer/Strategies/ReduceFlowDispatchOperandToResult.cpp b/compiler/src/iree/compiler/Reducer/Strategies/ReduceFlowDispatchOperandToResult.cpp index d0477a90b28d..22b8db3ea0be 100644 --- a/compiler/src/iree/compiler/Reducer/Strategies/ReduceFlowDispatchOperandToResult.cpp +++ b/compiler/src/iree/compiler/Reducer/Strategies/ReduceFlowDispatchOperandToResult.cpp @@ -42,16 +42,18 @@ void mlir::iree_compiler::Reducer::reduceFlowDispatchOperandToResultDelta( return; } - // Replace all dispatch ops with random inputs. + // Replace all dispatch ops with the chosen operand. for (auto [result, operand] : resultToOperand) { result.replaceAllUsesWith(operand); } - // Simplify. PassManager pm(module.getContext()); - pm.addPass(createCanonicalizerPass()); + // Dead code eliminate the dispatch ops. pm.addPass(createCSEPass()); + // Dead code eliminate the weights. pm.addPass(createSymbolDCEPass()); + // Canonicalize the module. + pm.addPass(createCanonicalizerPass()); if (failed(pm.run(module))) { return; } diff --git a/compiler/src/iree/compiler/Reducer/Strategies/ReduceFlowDispatchResultBySplat.cpp b/compiler/src/iree/compiler/Reducer/Strategies/ReduceFlowDispatchResultBySplat.cpp index f954e804a66e..c90f0fe6bcf4 100644 --- a/compiler/src/iree/compiler/Reducer/Strategies/ReduceFlowDispatchResultBySplat.cpp +++ b/compiler/src/iree/compiler/Reducer/Strategies/ReduceFlowDispatchResultBySplat.cpp @@ -80,4 +80,17 @@ void mlir::iree_compiler::Reducer::reduceFlowDispatchResultBySplatDelta( // Erase the dispatch. dispatch.erase(); } + + PassManager pm(module.getContext()); + // Dead code eliminate the dispatch ops. + pm.addPass(createCSEPass()); + // Dead code eliminate globals. + pm.addPass(createSymbolDCEPass()); + // Canonicalize so that the splats are fused with reshapes. + pm.addPass(createCanonicalizerPass()); + // CSE again to de-duplicate splats. + pm.addPass(createCSEPass()); + if (failed(pm.run(module))) { + return; + } } diff --git a/compiler/src/iree/compiler/Reducer/Strategies/ReduceLinalgOnTensorsDelta.cpp b/compiler/src/iree/compiler/Reducer/Strategies/ReduceLinalgOnTensorsDelta.cpp index 2b755b822a49..d85c990eb6ad 100644 --- a/compiler/src/iree/compiler/Reducer/Strategies/ReduceLinalgOnTensorsDelta.cpp +++ b/compiler/src/iree/compiler/Reducer/Strategies/ReduceLinalgOnTensorsDelta.cpp @@ -121,8 +121,12 @@ void mlir::iree_compiler::Reducer::reduceLinalgOnTensorsDelta( } PassManager pm(module.getContext()); - pm.addPass(createCanonicalizerPass()); + // Dead code eliminate. pm.addPass(createCSEPass()); + // De-duplicate identical fills. + pm.addPass(createCanonicalizerPass()); + // Remove dead globals. + pm.addPass(createSymbolDCEPass()); if (failed(pm.run(module))) return; }