From 0d85eafbee40c4887039466c787e24f7fca2f6e3 Mon Sep 17 00:00:00 2001 From: Romain Moyard Date: Mon, 30 Sep 2024 11:11:43 -0400 Subject: [PATCH 01/45] Add pattern --- mlir/include/Quantum/Transforms/Patterns.h | 1 + mlir/lib/Quantum/Transforms/CMakeLists.txt | 1 + .../Transforms/MergeRotationsPatterns.cpp | 52 +++++++++++++++++++ 3 files changed, 54 insertions(+) create mode 100644 mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp diff --git a/mlir/include/Quantum/Transforms/Patterns.h b/mlir/include/Quantum/Transforms/Patterns.h index 3005406c8a..c023fd6f81 100644 --- a/mlir/include/Quantum/Transforms/Patterns.h +++ b/mlir/include/Quantum/Transforms/Patterns.h @@ -26,6 +26,7 @@ void populateBufferizationPatterns(mlir::TypeConverter &, mlir::RewritePatternSe void populateQIRConversionPatterns(mlir::TypeConverter &, mlir::RewritePatternSet &); void populateAdjointPatterns(mlir::RewritePatternSet &); void populateSelfInversePatterns(mlir::RewritePatternSet &); +void populateMergeRotationsPatterns(mlir::RewritePatternSet &); } // namespace quantum } // namespace catalyst diff --git a/mlir/lib/Quantum/Transforms/CMakeLists.txt b/mlir/lib/Quantum/Transforms/CMakeLists.txt index edd440adc1..96ba30d23e 100644 --- a/mlir/lib/Quantum/Transforms/CMakeLists.txt +++ b/mlir/lib/Quantum/Transforms/CMakeLists.txt @@ -14,6 +14,7 @@ file(GLOB SRC remove_chained_self_inverse.cpp SplitMultipleTapes.cpp merge_rotation.cpp + MergeRotationsPatterns.cpp ) get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) diff --git a/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp b/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp new file mode 100644 index 0000000000..b9ad9e03df --- /dev/null +++ b/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp @@ -0,0 +1,52 @@ +// Copyright 2024 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#define DEBUG_TYPE "merge-rotations" + +#include "Quantum/IR/QuantumOps.h" +#include "Quantum/Transforms/Patterns.h" +#include "llvm/ADT/StringSet.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/Errc.h" +using llvm::dbgs; +using namespace mlir; +using namespace catalyst; +using namespace catalyst::quantum; + + +namespace { + +struct MergeRotationsRewritePattern : public mlir::OpRewritePattern { + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite(CustomOp op, mlir::PatternRewriter &rewriter) const override + { + LLVM_DEBUG(dbgs() << "Simplifying the following operation:\n" << op << "\n"); + + return success(); + } +}; + +} // namespace + +namespace catalyst { +namespace quantum { + +void populateMergeRotationsPatterns(RewritePatternSet &patterns) +{ + patterns.add(patterns.getContext(), 1); +} + +} // namespace quantum +} // namespace catalyst From 248ab5e4424fa6f2f746d199435013b582af4518 Mon Sep 17 00:00:00 2001 From: Romain Moyard Date: Mon, 30 Sep 2024 11:38:43 -0400 Subject: [PATCH 02/45] Update structure --- mlir/include/Quantum/Transforms/Passes.h | 2 +- mlir/include/Quantum/Transforms/Passes.td | 4 +- .../Catalyst/Transforms/RegisterAllPasses.cpp | 3 +- .../lib/Quantum/Transforms/merge_rotation.cpp | 38 +++++++++++++++---- 4 files changed, 35 insertions(+), 12 deletions(-) diff --git a/mlir/include/Quantum/Transforms/Passes.h b/mlir/include/Quantum/Transforms/Passes.h index 2e241bcd7d..23f3426b89 100644 --- a/mlir/include/Quantum/Transforms/Passes.h +++ b/mlir/include/Quantum/Transforms/Passes.h @@ -28,6 +28,6 @@ std::unique_ptr createAdjointLoweringPass(); std::unique_ptr createRemoveChainedSelfInversePass(); std::unique_ptr createAnnotateFunctionPass(); std::unique_ptr createSplitMultipleTapesPass(); -std::unique_ptr createMergeRotationPass(); +std::unique_ptr createMergeRotationsPass(); } // namespace catalyst diff --git a/mlir/include/Quantum/Transforms/Passes.td b/mlir/include/Quantum/Transforms/Passes.td index 202347c150..e9e71d3184 100644 --- a/mlir/include/Quantum/Transforms/Passes.td +++ b/mlir/include/Quantum/Transforms/Passes.td @@ -110,10 +110,10 @@ def RemoveChainedSelfInversePass : Pass<"remove-chained-self-inverse"> { let options = QuantumCircuitTransformationPass.options; } -def MergeRotationPass : Pass<"merge-rotation"> { +def MergeRotationPass : Pass<"merge-rotations"> { let summary = "merge rotation boilerplate words"; - let constructor = "catalyst::createMergeRotationPass()"; + let constructor = "catalyst::createMergeRotationsPass()"; let options = !listconcat( QuantumCircuitTransformationPass.options, [ diff --git a/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp b/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp index f629ce18f7..02ce02ec35 100644 --- a/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp +++ b/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp @@ -40,13 +40,14 @@ void catalyst::registerAllCatalystPasses() mlir::registerPass(catalyst::createHloCustomCallLoweringPass); mlir::registerPass(catalyst::createMemrefCopyToLinalgCopyPass); mlir::registerPass(catalyst::createMemrefToLLVMWithTBAAPass); - mlir::registerPass(catalyst::createMergeRotationPass); + mlir::registerPass(catalyst::createMergeRotationsPass); mlir::registerPass(catalyst::createMitigationLoweringPass); mlir::registerPass(catalyst::createQnodeToAsyncLoweringPass); mlir::registerPass(catalyst::createQuantumBufferizationPass); mlir::registerPass(catalyst::createQuantumConversionPass); mlir::registerPass(catalyst::createRegisterInactiveCallbackPass); mlir::registerPass(catalyst::createRemoveChainedSelfInversePass); + mlir::registerPass(catalyst::createMergeRotationsPass); mlir::registerPass(catalyst::createScatterLoweringPass); mlir::registerPass(catalyst::createSplitMultipleTapesPass); mlir::registerPass(catalyst::createTestPass); diff --git a/mlir/lib/Quantum/Transforms/merge_rotation.cpp b/mlir/lib/Quantum/Transforms/merge_rotation.cpp index 9e3753f9da..0fc5c997a7 100644 --- a/mlir/lib/Quantum/Transforms/merge_rotation.cpp +++ b/mlir/lib/Quantum/Transforms/merge_rotation.cpp @@ -14,11 +14,13 @@ #define DEBUG_TYPE "merge-rotation" -#include "mlir/Pass/Pass.h" -#include "llvm/Support/Debug.h" - #include "Catalyst/IR/CatalystDialect.h" #include "Quantum/IR/QuantumOps.h" +#include "Quantum/Transforms/Patterns.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/Support/Debug.h" using namespace llvm; using namespace mlir; @@ -39,18 +41,38 @@ struct MergeRotationPass : impl::MergeRotationPassBase { LLVM_DEBUG(dbgs() << "merge rotation pass" << "\n"); - if (MyOption == "aloha") { - llvm::errs() << "merge rotation pass, aloha!\n"; + Operation *module = getOperation(); + Operation *targetfunc; + + WalkResult result = module->walk([&](func::FuncOp op) { + StringRef funcName = op.getSymName(); + + if (funcName != FuncNameOpt) { + // not the function to run the pass on, visit the next function + return WalkResult::advance(); + } + targetfunc = op; + return WalkResult::interrupt(); + }); + + if (!result.wasInterrupted()) { + // Never met a target function + // Do nothing and exit! + return; } - else { - llvm::errs() << "merge rotation pass, hi!\n"; + + RewritePatternSet patterns(&getContext()); + populateMergeRotationsPatterns(patterns); + + if (failed(applyPatternsAndFoldGreedily(targetfunc, std::move(patterns)))) { + return signalPassFailure(); } } }; } // namespace quantum -std::unique_ptr createMergeRotationPass() +std::unique_ptr createMergeRotationsPass() { return std::make_unique(); } From 9ccb01aa98c5874601fe833ffec5edd69d8e64b9 Mon Sep 17 00:00:00 2001 From: Romain Moyard Date: Thu, 3 Oct 2024 14:57:24 -0400 Subject: [PATCH 03/45] Update --- mlir/include/Quantum/Transforms/Passes.td | 2 +- .../Transforms/MergeRotationsPatterns.cpp | 34 +++++++++++++++++++ .../lib/Quantum/Transforms/merge_rotation.cpp | 10 +++--- mlir/test/Quantum/MergeRotationsTest.mlir | 34 +++++++++++++++++++ 4 files changed, 74 insertions(+), 6 deletions(-) create mode 100644 mlir/test/Quantum/MergeRotationsTest.mlir diff --git a/mlir/include/Quantum/Transforms/Passes.td b/mlir/include/Quantum/Transforms/Passes.td index e9e71d3184..03bb771984 100644 --- a/mlir/include/Quantum/Transforms/Passes.td +++ b/mlir/include/Quantum/Transforms/Passes.td @@ -110,7 +110,7 @@ def RemoveChainedSelfInversePass : Pass<"remove-chained-self-inverse"> { let options = QuantumCircuitTransformationPass.options; } -def MergeRotationPass : Pass<"merge-rotations"> { +def MergeRotationsPass : Pass<"merge-rotations"> { let summary = "merge rotation boilerplate words"; let constructor = "catalyst::createMergeRotationsPass()"; diff --git a/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp b/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp index b9ad9e03df..cd35e6550b 100644 --- a/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp +++ b/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp @@ -16,6 +16,7 @@ #include "Quantum/IR/QuantumOps.h" #include "Quantum/Transforms/Patterns.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "llvm/ADT/StringSet.h" #include "llvm/Support/Debug.h" #include "llvm/Support/Errc.h" @@ -24,6 +25,7 @@ using namespace mlir; using namespace catalyst; using namespace catalyst::quantum; +static const mlir::StringSet<> rotationsSet = {"RX", "RY", "RZ"}; namespace { @@ -33,7 +35,39 @@ struct MergeRotationsRewritePattern : public mlir::OpRewritePattern { mlir::LogicalResult matchAndRewrite(CustomOp op, mlir::PatternRewriter &rewriter) const override { LLVM_DEBUG(dbgs() << "Simplifying the following operation:\n" << op << "\n"); + auto loc = op.getLoc(); + StringRef OpGateName = op.getGateName(); + if (!rotationsSet.contains(OpGateName)) + return failure(); + ValueRange InQubits = op.getInQubits(); + auto parentOp = dyn_cast_or_null(InQubits[0].getDefiningOp()); + if (!parentOp || parentOp.getGateName() != OpGateName) + return failure(); + ValueRange ParentOutQubits = parentOp.getOutQubits(); + // Check if the input qubits to the current operation match the output qubits of the parent. + for (const auto &[Idx, Qubit] : llvm::enumerate(InQubits)) { + if (Qubit.getDefiningOp() != parentOp || Qubit != ParentOutQubits[Idx]) + return failure(); + } + // Add angles + // Create new op + // Replace + if (OpGateName == "RX" || OpGateName == "RY" || OpGateName == "RZ") { + TypeRange OutQubitsTypes = op.getOutQubits().getTypes(); + TypeRange OutQubitsCtrlTypes = op.getOutCtrlQubits().getTypes(); + auto parentParams = parentOp.getParams().front(); + auto params = op.getParams().front(); + Value sumParams = rewriter.create(loc, parentParams, params).getResult(); + ValueRange parentInQubits = parentOp.getInQubits(); + ValueRange parentInCtrlQubits = parentOp.getInCtrlQubits(); + ValueRange parentInCtrlValues = parentOp.getInCtrlValues(); + auto mergeOp = rewriter.create(loc, OutQubitsTypes, OutQubitsCtrlTypes, sumParams, parentInQubits, OpGateName, nullptr, parentInCtrlQubits, parentInCtrlValues); + ModuleOp mod = op->getParentOfType(); + mod.dump(); + op.replaceAllUsesWith(mergeOp); + op.erase(); + } return success(); } }; diff --git a/mlir/lib/Quantum/Transforms/merge_rotation.cpp b/mlir/lib/Quantum/Transforms/merge_rotation.cpp index 0fc5c997a7..45819b0247 100644 --- a/mlir/lib/Quantum/Transforms/merge_rotation.cpp +++ b/mlir/lib/Quantum/Transforms/merge_rotation.cpp @@ -29,12 +29,12 @@ using namespace catalyst::quantum; namespace catalyst { namespace quantum { -#define GEN_PASS_DEF_MERGEROTATIONPASS -#define GEN_PASS_DECL_MERGEROTATIONPASS +#define GEN_PASS_DEF_MERGEROTATIONSPASS +#define GEN_PASS_DECL_MERGEROTATIONSPASS #include "Quantum/Transforms/Passes.h.inc" -struct MergeRotationPass : impl::MergeRotationPassBase { - using MergeRotationPassBase::MergeRotationPassBase; +struct MergeRotationsPass : impl::MergeRotationsPassBase { + using MergeRotationsPassBase::MergeRotationsPassBase; void runOnOperation() final { @@ -74,7 +74,7 @@ struct MergeRotationPass : impl::MergeRotationPassBase { std::unique_ptr createMergeRotationsPass() { - return std::make_unique(); + return std::make_unique(); } } // namespace catalyst diff --git a/mlir/test/Quantum/MergeRotationsTest.mlir b/mlir/test/Quantum/MergeRotationsTest.mlir new file mode 100644 index 0000000000..a5cdb09fca --- /dev/null +++ b/mlir/test/Quantum/MergeRotationsTest.mlir @@ -0,0 +1,34 @@ +// Copyright 2024 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// RUN: quantum-opt --pass-pipeline="builtin.module(merge-rotations{func-name=test_merge_rotations})" --split-input-file -verify-diagnostics %s | FileCheck %s + +func.func @test_merge_rotations(%arg0: f64, %arg1: f64) -> !quantum.bit { + %0 = quantum.alloc( 1) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + %2 = quantum.custom "RX"(%arg0) %1 : !quantum.bit + %3 = quantum.custom "RX"(%arg1) %2 : !quantum.bit + return %3 : !quantum.bit +} + +// ----- + +func.func @test_merge_rotations(%arg0: f64, %arg1: f64, %arg2: f64) -> !quantum.bit { + %0 = quantum.alloc( 1) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + %2 = quantum.custom "RX"(%arg0) %1 : !quantum.bit + %3 = quantum.custom "RX"(%arg1) %2 : !quantum.bit + %4 = quantum.custom "RX"(%arg2) %3 : !quantum.bit + return %3 : !quantum.bit +} From 3e5d1d133f4448920a36a0f319708e66c611433d Mon Sep 17 00:00:00 2001 From: Romain Moyard Date: Thu, 3 Oct 2024 15:22:42 -0400 Subject: [PATCH 04/45] Working draft --- mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp | 7 ++----- mlir/test/Quantum/MergeRotationsTest.mlir | 2 +- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp b/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp index cd35e6550b..9f7e14ef93 100644 --- a/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp +++ b/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp @@ -50,9 +50,7 @@ struct MergeRotationsRewritePattern : public mlir::OpRewritePattern { if (Qubit.getDefiningOp() != parentOp || Qubit != ParentOutQubits[Idx]) return failure(); } - // Add angles - // Create new op - // Replace + if (OpGateName == "RX" || OpGateName == "RY" || OpGateName == "RZ") { TypeRange OutQubitsTypes = op.getOutQubits().getTypes(); TypeRange OutQubitsCtrlTypes = op.getOutCtrlQubits().getTypes(); @@ -63,10 +61,9 @@ struct MergeRotationsRewritePattern : public mlir::OpRewritePattern { ValueRange parentInCtrlQubits = parentOp.getInCtrlQubits(); ValueRange parentInCtrlValues = parentOp.getInCtrlValues(); auto mergeOp = rewriter.create(loc, OutQubitsTypes, OutQubitsCtrlTypes, sumParams, parentInQubits, OpGateName, nullptr, parentInCtrlQubits, parentInCtrlValues); - ModuleOp mod = op->getParentOfType(); - mod.dump(); op.replaceAllUsesWith(mergeOp); op.erase(); + parentOp.erase(); } return success(); } diff --git a/mlir/test/Quantum/MergeRotationsTest.mlir b/mlir/test/Quantum/MergeRotationsTest.mlir index a5cdb09fca..9485b1e743 100644 --- a/mlir/test/Quantum/MergeRotationsTest.mlir +++ b/mlir/test/Quantum/MergeRotationsTest.mlir @@ -30,5 +30,5 @@ func.func @test_merge_rotations(%arg0: f64, %arg1: f64, %arg2: f64) -> !quantum. %2 = quantum.custom "RX"(%arg0) %1 : !quantum.bit %3 = quantum.custom "RX"(%arg1) %2 : !quantum.bit %4 = quantum.custom "RX"(%arg2) %3 : !quantum.bit - return %3 : !quantum.bit + return %4 : !quantum.bit } From c4f46baf6f457dffed79ee83ccbab2e695f1aca3 Mon Sep 17 00:00:00 2001 From: Romain Moyard Date: Thu, 3 Oct 2024 16:31:51 -0400 Subject: [PATCH 05/45] Update --- .../Quantum/Transforms/MergeRotationsPatterns.cpp | 4 +++- mlir/test/Quantum/MergeRotationsTest.mlir | 14 ++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp b/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp index 9f7e14ef93..129027e90b 100644 --- a/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp +++ b/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp @@ -60,7 +60,9 @@ struct MergeRotationsRewritePattern : public mlir::OpRewritePattern { ValueRange parentInQubits = parentOp.getInQubits(); ValueRange parentInCtrlQubits = parentOp.getInCtrlQubits(); ValueRange parentInCtrlValues = parentOp.getInCtrlValues(); - auto mergeOp = rewriter.create(loc, OutQubitsTypes, OutQubitsCtrlTypes, sumParams, parentInQubits, OpGateName, nullptr, parentInCtrlQubits, parentInCtrlValues); + auto mergeOp = rewriter.create(loc, OutQubitsTypes, OutQubitsCtrlTypes, + sumParams, parentInQubits, OpGateName, nullptr, + parentInCtrlQubits, parentInCtrlValues); op.replaceAllUsesWith(mergeOp); op.erase(); parentOp.erase(); diff --git a/mlir/test/Quantum/MergeRotationsTest.mlir b/mlir/test/Quantum/MergeRotationsTest.mlir index 9485b1e743..63deb80050 100644 --- a/mlir/test/Quantum/MergeRotationsTest.mlir +++ b/mlir/test/Quantum/MergeRotationsTest.mlir @@ -17,8 +17,14 @@ func.func @test_merge_rotations(%arg0: f64, %arg1: f64) -> !quantum.bit { %0 = quantum.alloc( 1) : !quantum.reg %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + // CHECK: [[reg:%.+]] = quantum.alloc( 1) : !quantum.reg + // CHECK: [[qubit:%.+]] = quantum.extract [[reg]][ 0] : !quantum.reg -> !quantum.bit + // CHECK: [[sum:%.+]] = arith.addf %arg0, %arg1 : f64 + // CHECK: [[ret:%.+]] = quantum.custom "RX"([[sum]]) [[qubit]] : !quantum.bit + // CHECK-NOT: quantum.custom "RX" %2 = quantum.custom "RX"(%arg0) %1 : !quantum.bit %3 = quantum.custom "RX"(%arg1) %2 : !quantum.bit + // CHECK: return [[ret]] return %3 : !quantum.bit } @@ -27,8 +33,16 @@ func.func @test_merge_rotations(%arg0: f64, %arg1: f64) -> !quantum.bit { func.func @test_merge_rotations(%arg0: f64, %arg1: f64, %arg2: f64) -> !quantum.bit { %0 = quantum.alloc( 1) : !quantum.reg %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + + // CHECK: [[reg:%.+]] = quantum.alloc( 1) : !quantum.reg + // CHECK: [[qubit:%.+]] = quantum.extract [[reg]][ 0] : !quantum.reg -> !quantum.bit + // CHECK: [[sum1:%.+]] = arith.addf %arg1, %arg2 : f64 + // CHECK: [[sum2:%.+]] = arith.addf %arg0, [[sum1]] : f64 + // CHECK: [[ret:%.+]] = quantum.custom "RX"([[sum2]]) [[qubit]] : !quantum.bit + // CHECK-NOT: quantum.custom "RX" %2 = quantum.custom "RX"(%arg0) %1 : !quantum.bit %3 = quantum.custom "RX"(%arg1) %2 : !quantum.bit %4 = quantum.custom "RX"(%arg2) %3 : !quantum.bit + // CHECK: return [[ret]] return %4 : !quantum.bit } From 7286abcea8fd67c341bf307dc7c5c6316f658e50 Mon Sep 17 00:00:00 2001 From: Haochen Wang Date: Mon, 7 Oct 2024 10:35:33 -0400 Subject: [PATCH 06/45] renamed to `ChainedNamedHermitianOpRewritePattern` --- mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp b/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp index 2e6486fb7a..35272bfffe 100644 --- a/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp +++ b/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp @@ -29,7 +29,7 @@ static const mlir::StringSet<> HermitianOps = {"Hadamard", "PauliX", "PauliY", " namespace { -struct ChainedHadamardOpRewritePattern : public mlir::OpRewritePattern { +struct ChainedNamedHermitianOpRewritePattern : public mlir::OpRewritePattern { using mlir::OpRewritePattern::OpRewritePattern; /// We simplify consecutive Hermitian quantum gates by removing them. @@ -68,7 +68,7 @@ namespace quantum { void populateSelfInversePatterns(RewritePatternSet &patterns) { - patterns.add(patterns.getContext(), 1); + patterns.add(patterns.getContext(), 1); } } // namespace quantum From e1f54e5c893d8978d7c5d9af7c4948fe6c7ce887 Mon Sep 17 00:00:00 2001 From: Romain Moyard Date: Mon, 7 Oct 2024 11:45:10 -0400 Subject: [PATCH 07/45] Add test --- .../Transforms/MergeRotationsPatterns.cpp | 7 +++++-- mlir/test/Quantum/MergeRotationsTest.mlir | 21 +++++++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp b/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp index 129027e90b..5a367057b4 100644 --- a/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp +++ b/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp @@ -25,7 +25,7 @@ using namespace mlir; using namespace catalyst; using namespace catalyst::quantum; -static const mlir::StringSet<> rotationsSet = {"RX", "RY", "RZ"}; +static const mlir::StringSet<> rotationsSet = {"RX", "RY", "RZ", "CRX", "CRY", "CRZ"}; namespace { @@ -51,7 +51,10 @@ struct MergeRotationsRewritePattern : public mlir::OpRewritePattern { return failure(); } - if (OpGateName == "RX" || OpGateName == "RY" || OpGateName == "RZ") { + + static const mlir::StringSet<> rotationsSetCase1 = {"RX", "RY", "RZ", "CRX", "CRY", "CRZ"}; + + if (rotationsSetCase1.find(OpGateName) != rotationsSetCase1.end()) { TypeRange OutQubitsTypes = op.getOutQubits().getTypes(); TypeRange OutQubitsCtrlTypes = op.getOutCtrlQubits().getTypes(); auto parentParams = parentOp.getParams().front(); diff --git a/mlir/test/Quantum/MergeRotationsTest.mlir b/mlir/test/Quantum/MergeRotationsTest.mlir index 63deb80050..c285e82e8b 100644 --- a/mlir/test/Quantum/MergeRotationsTest.mlir +++ b/mlir/test/Quantum/MergeRotationsTest.mlir @@ -46,3 +46,24 @@ func.func @test_merge_rotations(%arg0: f64, %arg1: f64, %arg2: f64) -> !quantum. // CHECK: return [[ret]] return %4 : !quantum.bit } + +// ----- + +func.func @test_merge_rotations(%arg0: f64, %arg1: f64, %arg2: f64) -> (!quantum.bit, !quantum.bit) { + %0 = quantum.alloc( 2) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + %2 = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit + + // CHECK: [[reg:%.+]] = quantum.alloc( 2) : !quantum.reg + // CHECK: [[qubit1:%.+]] = quantum.extract [[reg]][ 0] : !quantum.reg -> !quantum.bit + // CHECK: [[qubit2:%.+]] = quantum.extract [[reg]][ 1] : !quantum.reg -> !quantum.bit + // CHECK: [[sum1:%.+]] = arith.addf %arg1, %arg2 : f64 + // CHECK: [[sum2:%.+]] = arith.addf %arg0, [[sum1]] : f64 + // CHECK: [[ret:%.+]]:2 = quantum.custom "CRX"([[sum2]]) [[qubit1]], [[qubit2]] : !quantum.bit, !quantum.bit + // CHECK-NOT: quantum.custom "CRX" + %3:2 = quantum.custom "CRX"(%arg0) %1, %2: !quantum.bit, !quantum.bit + %4:2 = quantum.custom "CRX"(%arg1) %3#0, %3#1 : !quantum.bit, !quantum.bit + %5:2 = quantum.custom "CRX"(%arg2) %4#0, %4#1 : !quantum.bit, !quantum.bit + // CHECK: return [[ret]]#0, [[ret]]#1 + return %5#0, %5#1 : !quantum.bit, !quantum.bit +} \ No newline at end of file From 043a5f0a38f2f815773f3d33ea849a7a55cc1b9c Mon Sep 17 00:00:00 2001 From: Romain Moyard Date: Mon, 7 Oct 2024 11:51:12 -0400 Subject: [PATCH 08/45] MLIR test: CRY switch qubits --- mlir/test/Quantum/MergeRotationsTest.mlir | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/mlir/test/Quantum/MergeRotationsTest.mlir b/mlir/test/Quantum/MergeRotationsTest.mlir index c285e82e8b..7b8c214356 100644 --- a/mlir/test/Quantum/MergeRotationsTest.mlir +++ b/mlir/test/Quantum/MergeRotationsTest.mlir @@ -66,4 +66,25 @@ func.func @test_merge_rotations(%arg0: f64, %arg1: f64, %arg2: f64) -> (!quantum %5:2 = quantum.custom "CRX"(%arg2) %4#0, %4#1 : !quantum.bit, !quantum.bit // CHECK: return [[ret]]#0, [[ret]]#1 return %5#0, %5#1 : !quantum.bit, !quantum.bit +} + +// ----- + +func.func @test_merge_rotations(%arg0: f64, %arg1: f64, %arg2: f64) -> (!quantum.bit, !quantum.bit) { + %0 = quantum.alloc( 2) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + %2 = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit + + // CHECK: [[reg:%.+]] = quantum.alloc( 2) : !quantum.reg + // CHECK: [[qubit1:%.+]] = quantum.extract [[reg]][ 0] : !quantum.reg -> !quantum.bit + // CHECK: [[qubit2:%.+]] = quantum.extract [[reg]][ 1] : !quantum.reg -> !quantum.bit + // CHECK: [[sum:%.+]] = arith.addf %arg0, %arg1 : f64 + // CHECK: [[qubits3:%.+]]:2 = quantum.custom "CRY"([[sum]]) [[qubit1]], [[qubit2]] : !quantum.bit, !quantum.bit + // CHECK: [[ret:%.+]]:2 = quantum.custom "CRY"(%arg2) [[qubits3]]#1, [[qubits3]]#0 : !quantum.bit, !quantum.bit + // CHECK-NOT: quantum.custom "CRY" + %3:2 = quantum.custom "CRY"(%arg0) %1, %2: !quantum.bit, !quantum.bit + %4:2 = quantum.custom "CRY"(%arg1) %3#0, %3#1 : !quantum.bit, !quantum.bit + %5:2 = quantum.custom "CRY"(%arg2) %4#1, %4#0 : !quantum.bit, !quantum.bit + // CHECK: return [[ret]]#0, [[ret]]#1 + return %5#0, %5#1 : !quantum.bit, !quantum.bit } \ No newline at end of file From 1f879437d6719e7fc1c78929c8811ff9e28f7fec Mon Sep 17 00:00:00 2001 From: Haochen Wang Date: Mon, 7 Oct 2024 13:23:21 -0400 Subject: [PATCH 09/45] add pattern --- .../Transforms/ChainedSelfInversePatterns.cpp | 84 ++++++++++++++++++- mlir/test/Quantum/ChainedSelfInverseTest.mlir | 18 ++++ 2 files changed, 99 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp b/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp index 35272bfffe..f18b1d3398 100644 --- a/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp +++ b/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp @@ -41,19 +41,22 @@ struct ChainedNamedHermitianOpRewritePattern : public mlir::OpRewritePattern(InQubits[0].getDefiningOp()); - if (!ParentOp || ParentOp.getGateName() != OpGateName) + if (!ParentOp || ParentOp.getGateName() != OpGateName){ return failure(); + } ValueRange ParentOutQubits = ParentOp.getOutQubits(); // Check if the input qubits to the current operation match the output qubits of the parent. for (const auto &[Idx, Qubit] : llvm::enumerate(InQubits)) { - if (Qubit.getDefiningOp() != ParentOp || Qubit != ParentOutQubits[Idx]) + if (Qubit.getDefiningOp() != ParentOp || Qubit != ParentOutQubits[Idx]){ return failure(); + } } ValueRange simplifiedVal = ParentOp.getInQubits(); rewriter.replaceOp(op, simplifiedVal); @@ -61,6 +64,79 @@ struct ChainedNamedHermitianOpRewritePattern : public mlir::OpRewritePattern +struct ChainedUUadjOpRewritePattern : public mlir::OpRewritePattern { + using mlir::OpRewritePattern::OpRewritePattern; + + bool verifyParentGateType(OpType op, OpType parentOp) const { + // Verify that the parent gate is of the same type, + // and parent's results and current gate's inputs are in the same order + // If OpType is quantum.custom, also verify that parent gate has the + // same gate name. + + if (!parentOp || !isa(parentOp)){ + return false; + } + + if (isa(op)){ + StringRef OpGateName = cast(op).getGateName(); + StringRef ParentGateName = cast(parentOp).getGateName(); + if (OpGateName != ParentGateName){ + return false; + } + } + + ValueRange InQubits = op.getInQubits(); + ValueRange ParentOutQubits = parentOp.getOutQubits(); + for (const auto &[Idx, Qubit] : llvm::enumerate(InQubits)) { + if (Qubit.getDefiningOp() != parentOp || Qubit != ParentOutQubits[Idx]){ + return false; + } + } + + return true; + } + + + bool verifyParentGateParams(OpType op, OpType parentOp) const { + // Verify that the parent gate has the same parameters + return true; + } + + /// Remove generic neighbouring gate pairs of the form + /// --- gate --- gate{adjoint} --- + /// Conditions: + /// 1. Both gates must be of the same type, i.e. a quantum.custom can + /// only be cancelled with a quantum.custom, not a quantum.unitary + /// 2. The results of the parent gate must map one-to-one, in order, + /// to the operands of the second gate + /// 3. If there are parameters, both gate must have the same parameters. + /// 4. If the gates are controlled, both gates' control wires and values + /// must be the same. The control wires must be in the same order + mlir::LogicalResult matchAndRewrite(OpType op, mlir::PatternRewriter &rewriter) const override + { + LLVM_DEBUG(dbgs() << "Simplifying the following operation:\n" << op << "\n"); + + llvm::errs() << "visiting " << op << "\n"; + + ValueRange InQubits = op.getInQubits(); + auto parentOp = dyn_cast_or_null(InQubits[0].getDefiningOp()); + + if (!verifyParentGateType(op, parentOp)){ + return failure(); + } + + + + llvm::errs() << "matched!\n"; + ValueRange simplifiedVal = parentOp.getInQubits(); + rewriter.replaceOp(op, simplifiedVal); + return success(); + } + + +}; + } // namespace namespace catalyst { @@ -69,6 +145,8 @@ namespace quantum { void populateSelfInversePatterns(RewritePatternSet &patterns) { patterns.add(patterns.getContext(), 1); + patterns.add>(patterns.getContext(), 1); + patterns.add>(patterns.getContext(), 1); } } // namespace quantum diff --git a/mlir/test/Quantum/ChainedSelfInverseTest.mlir b/mlir/test/Quantum/ChainedSelfInverseTest.mlir index c4866ddc0e..36b50c89db 100644 --- a/mlir/test/Quantum/ChainedSelfInverseTest.mlir +++ b/mlir/test/Quantum/ChainedSelfInverseTest.mlir @@ -269,3 +269,21 @@ func.func @test_chained_self_inverse() -> !quantum.bit { return %4 : !quantum.bit } + +// ----- + +// test parametrized gates labeled with adjoint attribute +// CHECK-LABEL: test_chained_self_inverse +func.func @test_chained_self_inverse(%arg0: tensor<2x2xf64>) -> !quantum.bit { + // CHECK: quantum.alloc + // CHECK: quantum.extract + %0 = quantum.alloc( 1) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + + %2 = stablehlo.convert %arg0 : (tensor<2x2xf64>) -> tensor<2x2xcomplex> + %out_qubits = quantum.unitary(%2 : tensor<2x2xcomplex>) %1 : !quantum.bit + %3 = stablehlo.convert %arg0 : (tensor<2x2xf64>) -> tensor<2x2xcomplex> + %out_qubits_1 = quantum.unitary(%3 : tensor<2x2xcomplex>) %out_qubits {adjoint} : !quantum.bit + + return %out_qubits_1 : !quantum.bit +} From 1458937c288cc17dd2baacf3f410da4890fa2f13 Mon Sep 17 00:00:00 2001 From: Haochen Wang Date: Mon, 7 Oct 2024 14:47:10 -0400 Subject: [PATCH 10/45] preprocess with cse pass so we can check param SSA values; check the pair has exactly one adjoint --- .../Transforms/ChainedSelfInversePatterns.cpp | 36 +++++++++++++++++-- .../remove_chained_self_inverse.cpp | 9 +++++ mlir/test/Quantum/ChainedSelfInverseTest.mlir | 10 ++++-- 3 files changed, 50 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp b/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp index f18b1d3398..4bca84c230 100644 --- a/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp +++ b/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp @@ -79,9 +79,9 @@ struct ChainedUUadjOpRewritePattern : public mlir::OpRewritePattern { } if (isa(op)){ - StringRef OpGateName = cast(op).getGateName(); - StringRef ParentGateName = cast(parentOp).getGateName(); - if (OpGateName != ParentGateName){ + StringRef opGateName = cast(op).getGateName(); + StringRef parentGateName = cast(parentOp).getGateName(); + if (opGateName != parentGateName){ return false; } } @@ -100,9 +100,30 @@ struct ChainedUUadjOpRewritePattern : public mlir::OpRewritePattern { bool verifyParentGateParams(OpType op, OpType parentOp) const { // Verify that the parent gate has the same parameters + + ValueRange opParams = op.getAllParams(); + ValueRange parentOpParams = parentOp.getAllParams(); + + if (opParams.size() != parentOpParams.size()){ + return false; + } + + for (auto [opParam, parentOpParam] : llvm::zip(opParams, parentOpParams)){ + if (opParam != parentOpParam){ + return false; + } + } + return true; } + bool verifyOneAdjoint(OpType op, OpType parentOp) const { + // Verify that exactly one of the neighbouring pair is an adjoint + bool opIsAdj = op->hasAttr("adjoint"); + bool parentIsAdj = parentOp->hasAttr("adjoint"); + return opIsAdj != parentIsAdj; // "XOR" to check just one true + } + /// Remove generic neighbouring gate pairs of the form /// --- gate --- gate{adjoint} --- /// Conditions: @@ -111,6 +132,7 @@ struct ChainedUUadjOpRewritePattern : public mlir::OpRewritePattern { /// 2. The results of the parent gate must map one-to-one, in order, /// to the operands of the second gate /// 3. If there are parameters, both gate must have the same parameters. + /// [This pattern assumes the IR is already processed by CSE] /// 4. If the gates are controlled, both gates' control wires and values /// must be the same. The control wires must be in the same order mlir::LogicalResult matchAndRewrite(OpType op, mlir::PatternRewriter &rewriter) const override @@ -127,6 +149,14 @@ struct ChainedUUadjOpRewritePattern : public mlir::OpRewritePattern { } + if (!verifyParentGateParams(op, parentOp)){ + return failure(); + } + + + if (!verifyOneAdjoint(op, parentOp)){ + return failure(); + } llvm::errs() << "matched!\n"; ValueRange simplifiedVal = parentOp.getInQubits(); diff --git a/mlir/lib/Quantum/Transforms/remove_chained_self_inverse.cpp b/mlir/lib/Quantum/Transforms/remove_chained_self_inverse.cpp index d510fe9036..36334a2b12 100644 --- a/mlir/lib/Quantum/Transforms/remove_chained_self_inverse.cpp +++ b/mlir/lib/Quantum/Transforms/remove_chained_self_inverse.cpp @@ -24,6 +24,8 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/SymbolTable.h" #include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/Passes.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "Catalyst/IR/CatalystDialect.h" @@ -50,6 +52,13 @@ struct RemoveChainedSelfInversePass LLVM_DEBUG(dbgs() << "remove chained self inverse pass" << "\n"); + // Run cse pass before running remove-chained-self-inverse, + // to aid identifiying equivalent SSA values when verifying + // the gates have the same params + MLIRContext *ctx = &getContext(); + auto pm = PassManager::on(ctx); + pm.addPass(mlir::createCSEPass()); + Operation *module = getOperation(); Operation *targetfunc; diff --git a/mlir/test/Quantum/ChainedSelfInverseTest.mlir b/mlir/test/Quantum/ChainedSelfInverseTest.mlir index 36b50c89db..9d901c2a7d 100644 --- a/mlir/test/Quantum/ChainedSelfInverseTest.mlir +++ b/mlir/test/Quantum/ChainedSelfInverseTest.mlir @@ -274,7 +274,7 @@ func.func @test_chained_self_inverse() -> !quantum.bit { // test parametrized gates labeled with adjoint attribute // CHECK-LABEL: test_chained_self_inverse -func.func @test_chained_self_inverse(%arg0: tensor<2x2xf64>) -> !quantum.bit { +func.func @test_chained_self_inverse(%arg0: tensor<2x2xf64>, %arg1: tensor, %arg2: tensor) -> (!quantum.bit, !quantum.bit) { // CHECK: quantum.alloc // CHECK: quantum.extract %0 = quantum.alloc( 1) : !quantum.reg @@ -285,5 +285,11 @@ func.func @test_chained_self_inverse(%arg0: tensor<2x2xf64>) -> !quantum.bit { %3 = stablehlo.convert %arg0 : (tensor<2x2xf64>) -> tensor<2x2xcomplex> %out_qubits_1 = quantum.unitary(%3 : tensor<2x2xcomplex>) %out_qubits {adjoint} : !quantum.bit - return %out_qubits_1 : !quantum.bit + %extracted_3 = tensor.extract %arg1[] : tensor + %out_qubits_4 = quantum.custom "RX"(%extracted_3) %out_qubits_1 : !quantum.bit + %extracted_5 = tensor.extract %arg1[] : tensor + %out_qubits_6 = quantum.custom "RX"(%extracted_5) %out_qubits_4 {adjoint} : !quantum.bit + + + return %out_qubits_1, %out_qubits_6 : !quantum.bit, !quantum.bit } From b8b991327e014b0d091b91b90ba1d177709bb8bc Mon Sep 17 00:00:00 2001 From: Haochen Wang Date: Mon, 7 Oct 2024 15:10:14 -0400 Subject: [PATCH 11/45] tests --- .../Transforms/ChainedSelfInversePatterns.cpp | 6 ++-- .../remove_chained_self_inverse.cpp | 9 +++-- .../Catalyst/ApplyTransformSequenceTest.mlir | 8 ++--- mlir/test/Quantum/ChainedSelfInverseTest.mlir | 35 ++++++++++++++----- 4 files changed, 39 insertions(+), 19 deletions(-) diff --git a/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp b/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp index 4bca84c230..569e71d293 100644 --- a/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp +++ b/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp @@ -139,7 +139,7 @@ struct ChainedUUadjOpRewritePattern : public mlir::OpRewritePattern { { LLVM_DEBUG(dbgs() << "Simplifying the following operation:\n" << op << "\n"); - llvm::errs() << "visiting " << op << "\n"; + //llvm::errs() << "visiting " << op << "\n"; ValueRange InQubits = op.getInQubits(); auto parentOp = dyn_cast_or_null(InQubits[0].getDefiningOp()); @@ -158,13 +158,11 @@ struct ChainedUUadjOpRewritePattern : public mlir::OpRewritePattern { return failure(); } - llvm::errs() << "matched!\n"; + //llvm::errs() << "matched!\n"; ValueRange simplifiedVal = parentOp.getInQubits(); rewriter.replaceOp(op, simplifiedVal); return success(); } - - }; } // namespace diff --git a/mlir/lib/Quantum/Transforms/remove_chained_self_inverse.cpp b/mlir/lib/Quantum/Transforms/remove_chained_self_inverse.cpp index 36334a2b12..4ac2cb2190 100644 --- a/mlir/lib/Quantum/Transforms/remove_chained_self_inverse.cpp +++ b/mlir/lib/Quantum/Transforms/remove_chained_self_inverse.cpp @@ -53,11 +53,14 @@ struct RemoveChainedSelfInversePass << "\n"); // Run cse pass before running remove-chained-self-inverse, - // to aid identifiying equivalent SSA values when verifying + // to aid identifying equivalent SSA values when verifying // the gates have the same params MLIRContext *ctx = &getContext(); - auto pm = PassManager::on(ctx); - pm.addPass(mlir::createCSEPass()); + auto earlyCSEpm = PassManager::on(ctx); + earlyCSEpm.addPass(mlir::createCSEPass()); + if (failed(runPipeline(earlyCSEpm, getOperation()))){ + return signalPassFailure(); + } Operation *module = getOperation(); Operation *targetfunc; diff --git a/mlir/test/Catalyst/ApplyTransformSequenceTest.mlir b/mlir/test/Catalyst/ApplyTransformSequenceTest.mlir index fdc91695d8..6d4a4dc47c 100644 --- a/mlir/test/Catalyst/ApplyTransformSequenceTest.mlir +++ b/mlir/test/Catalyst/ApplyTransformSequenceTest.mlir @@ -29,25 +29,25 @@ module @workflow { } } - func.func private @f(%arg0: tensor) -> tensor { + func.func private @f(%arg0: tensor) -> !quantum.bit { %c_0 = stablehlo.constant dense<0> : tensor %extracted = tensor.extract %c_0[] : tensor %0 = quantum.alloc( 1) : !quantum.reg %1 = quantum.extract %0[%extracted] : !quantum.reg -> !quantum.bit %out_qubits = quantum.custom "Hadamard"() %1 : !quantum.bit %out_qubits_1 = quantum.custom "Hadamard"() %out_qubits : !quantum.bit - return %arg0 : tensor + return %out_qubits_1 : !quantum.bit } - func.func private @g(%arg0: tensor) -> tensor { + func.func private @g(%arg0: tensor) -> !quantum.bit { %c_0 = stablehlo.constant dense<0> : tensor %extracted = tensor.extract %c_0[] : tensor %0 = quantum.alloc( 1) : !quantum.reg %1 = quantum.extract %0[%extracted] : !quantum.reg -> !quantum.bit %out_qubits = quantum.custom "Hadamard"() %1 : !quantum.bit %out_qubits_1 = quantum.custom "Hadamard"() %out_qubits : !quantum.bit - return %arg0 : tensor + return %out_qubits_1 : !quantum.bit } } diff --git a/mlir/test/Quantum/ChainedSelfInverseTest.mlir b/mlir/test/Quantum/ChainedSelfInverseTest.mlir index 9d901c2a7d..2eacd1b98e 100644 --- a/mlir/test/Quantum/ChainedSelfInverseTest.mlir +++ b/mlir/test/Quantum/ChainedSelfInverseTest.mlir @@ -272,11 +272,12 @@ func.func @test_chained_self_inverse() -> !quantum.bit { // ----- -// test parametrized gates labeled with adjoint attribute + +// test quantum.unitary labeled with adjoint attribute // CHECK-LABEL: test_chained_self_inverse -func.func @test_chained_self_inverse(%arg0: tensor<2x2xf64>, %arg1: tensor, %arg2: tensor) -> (!quantum.bit, !quantum.bit) { +func.func @test_chained_self_inverse(%arg0: tensor<2x2xf64>, %arg1: tensor) -> !quantum.bit { // CHECK: quantum.alloc - // CHECK: quantum.extract + // CHECK: [[IN:%.+]] = quantum.extract %0 = quantum.alloc( 1) : !quantum.reg %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit @@ -285,11 +286,29 @@ func.func @test_chained_self_inverse(%arg0: tensor<2x2xf64>, %arg1: tensor, %3 = stablehlo.convert %arg0 : (tensor<2x2xf64>) -> tensor<2x2xcomplex> %out_qubits_1 = quantum.unitary(%3 : tensor<2x2xcomplex>) %out_qubits {adjoint} : !quantum.bit - %extracted_3 = tensor.extract %arg1[] : tensor - %out_qubits_4 = quantum.custom "RX"(%extracted_3) %out_qubits_1 : !quantum.bit - %extracted_5 = tensor.extract %arg1[] : tensor - %out_qubits_6 = quantum.custom "RX"(%extracted_5) %out_qubits_4 {adjoint} : !quantum.bit + // CHECK-NOT: quantum.unitary + // CHECK: return [[IN]] + return %out_qubits_1 : !quantum.bit +} - return %out_qubits_1, %out_qubits_6 : !quantum.bit, !quantum.bit +// ----- + + +// test quantum.custom labeled with adjoint attribute +// CHECK-LABEL: test_chained_self_inverse +func.func @test_chained_self_inverse(%arg0: tensor) -> !quantum.bit { + // CHECK: quantum.alloc + // CHECK: [[IN:%.+]] = quantum.extract + %0 = quantum.alloc( 1) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + + %extracted_0 = tensor.extract %arg0[] : tensor + %out_qubits = quantum.custom "RX"(%extracted_0) %1 : !quantum.bit + %extracted_1 = tensor.extract %arg0[] : tensor + %out_qubits_1 = quantum.custom "RX"(%extracted_1) %out_qubits {adjoint} : !quantum.bit + + // CHECK-NOT: quantum.custom + // CHECK: return [[IN]] + return %out_qubits_1 : !quantum.bit } From 43fcce7cce835de3b00bfed015b66e6b5e7cb446 Mon Sep 17 00:00:00 2001 From: Haochen Wang Date: Mon, 7 Oct 2024 15:12:19 -0400 Subject: [PATCH 12/45] format --- .../Transforms/ChainedSelfInversePatterns.cpp | 42 +++++++++---------- .../remove_chained_self_inverse.cpp | 4 +- 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp b/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp index 569e71d293..889c5f093e 100644 --- a/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp +++ b/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp @@ -41,20 +41,20 @@ struct ChainedNamedHermitianOpRewritePattern : public mlir::OpRewritePattern(InQubits[0].getDefiningOp()); - if (!ParentOp || ParentOp.getGateName() != OpGateName){ + if (!ParentOp || ParentOp.getGateName() != OpGateName) { return failure(); } ValueRange ParentOutQubits = ParentOp.getOutQubits(); // Check if the input qubits to the current operation match the output qubits of the parent. for (const auto &[Idx, Qubit] : llvm::enumerate(InQubits)) { - if (Qubit.getDefiningOp() != ParentOp || Qubit != ParentOutQubits[Idx]){ + if (Qubit.getDefiningOp() != ParentOp || Qubit != ParentOutQubits[Idx]) { return failure(); } } @@ -68,20 +68,21 @@ template struct ChainedUUadjOpRewritePattern : public mlir::OpRewritePattern { using mlir::OpRewritePattern::OpRewritePattern; - bool verifyParentGateType(OpType op, OpType parentOp) const { + bool verifyParentGateType(OpType op, OpType parentOp) const + { // Verify that the parent gate is of the same type, // and parent's results and current gate's inputs are in the same order // If OpType is quantum.custom, also verify that parent gate has the // same gate name. - if (!parentOp || !isa(parentOp)){ + if (!parentOp || !isa(parentOp)) { return false; } - if (isa(op)){ + if (isa(op)) { StringRef opGateName = cast(op).getGateName(); StringRef parentGateName = cast(parentOp).getGateName(); - if (opGateName != parentGateName){ + if (opGateName != parentGateName) { return false; } } @@ -89,7 +90,7 @@ struct ChainedUUadjOpRewritePattern : public mlir::OpRewritePattern { ValueRange InQubits = op.getInQubits(); ValueRange ParentOutQubits = parentOp.getOutQubits(); for (const auto &[Idx, Qubit] : llvm::enumerate(InQubits)) { - if (Qubit.getDefiningOp() != parentOp || Qubit != ParentOutQubits[Idx]){ + if (Qubit.getDefiningOp() != parentOp || Qubit != ParentOutQubits[Idx]) { return false; } } @@ -97,19 +98,19 @@ struct ChainedUUadjOpRewritePattern : public mlir::OpRewritePattern { return true; } - - bool verifyParentGateParams(OpType op, OpType parentOp) const { + bool verifyParentGateParams(OpType op, OpType parentOp) const + { // Verify that the parent gate has the same parameters ValueRange opParams = op.getAllParams(); ValueRange parentOpParams = parentOp.getAllParams(); - if (opParams.size() != parentOpParams.size()){ + if (opParams.size() != parentOpParams.size()) { return false; } - for (auto [opParam, parentOpParam] : llvm::zip(opParams, parentOpParams)){ - if (opParam != parentOpParam){ + for (auto [opParam, parentOpParam] : llvm::zip(opParams, parentOpParams)) { + if (opParam != parentOpParam) { return false; } } @@ -117,7 +118,8 @@ struct ChainedUUadjOpRewritePattern : public mlir::OpRewritePattern { return true; } - bool verifyOneAdjoint(OpType op, OpType parentOp) const { + bool verifyOneAdjoint(OpType op, OpType parentOp) const + { // Verify that exactly one of the neighbouring pair is an adjoint bool opIsAdj = op->hasAttr("adjoint"); bool parentIsAdj = parentOp->hasAttr("adjoint"); @@ -139,26 +141,24 @@ struct ChainedUUadjOpRewritePattern : public mlir::OpRewritePattern { { LLVM_DEBUG(dbgs() << "Simplifying the following operation:\n" << op << "\n"); - //llvm::errs() << "visiting " << op << "\n"; + // llvm::errs() << "visiting " << op << "\n"; ValueRange InQubits = op.getInQubits(); auto parentOp = dyn_cast_or_null(InQubits[0].getDefiningOp()); - if (!verifyParentGateType(op, parentOp)){ + if (!verifyParentGateType(op, parentOp)) { return failure(); } - - if (!verifyParentGateParams(op, parentOp)){ + if (!verifyParentGateParams(op, parentOp)) { return failure(); } - - if (!verifyOneAdjoint(op, parentOp)){ + if (!verifyOneAdjoint(op, parentOp)) { return failure(); } - //llvm::errs() << "matched!\n"; + // llvm::errs() << "matched!\n"; ValueRange simplifiedVal = parentOp.getInQubits(); rewriter.replaceOp(op, simplifiedVal); return success(); diff --git a/mlir/lib/Quantum/Transforms/remove_chained_self_inverse.cpp b/mlir/lib/Quantum/Transforms/remove_chained_self_inverse.cpp index 4ac2cb2190..5b84674b15 100644 --- a/mlir/lib/Quantum/Transforms/remove_chained_self_inverse.cpp +++ b/mlir/lib/Quantum/Transforms/remove_chained_self_inverse.cpp @@ -25,8 +25,8 @@ #include "mlir/IR/SymbolTable.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" -#include "mlir/Transforms/Passes.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" #include "Catalyst/IR/CatalystDialect.h" #include "Quantum/IR/QuantumOps.h" @@ -58,7 +58,7 @@ struct RemoveChainedSelfInversePass MLIRContext *ctx = &getContext(); auto earlyCSEpm = PassManager::on(ctx); earlyCSEpm.addPass(mlir::createCSEPass()); - if (failed(runPipeline(earlyCSEpm, getOperation()))){ + if (failed(runPipeline(earlyCSEpm, getOperation()))) { return signalPassFailure(); } From 799d96e0c6b018f9f71cf96645be8fc48ede6aa2 Mon Sep 17 00:00:00 2001 From: Haochen Wang Date: Mon, 7 Oct 2024 15:19:28 -0400 Subject: [PATCH 13/45] test with explicit rotation angles --- mlir/test/Quantum/ChainedSelfInverseTest.mlir | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/mlir/test/Quantum/ChainedSelfInverseTest.mlir b/mlir/test/Quantum/ChainedSelfInverseTest.mlir index 2eacd1b98e..b8af8ee27d 100644 --- a/mlir/test/Quantum/ChainedSelfInverseTest.mlir +++ b/mlir/test/Quantum/ChainedSelfInverseTest.mlir @@ -312,3 +312,28 @@ func.func @test_chained_self_inverse(%arg0: tensor) -> !quantum.bit { // CHECK: return [[IN]] return %out_qubits_1 : !quantum.bit } + + +// ----- + + +// test with explicit rotation angles +// CHECK-LABEL: test_chained_self_inverse +func.func @test_chained_self_inverse() -> !quantum.bit { + // CHECK: quantum.alloc + // CHECK: [[IN:%.+]] = quantum.extract + %0 = quantum.alloc( 1) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + + %cst_0 = stablehlo.constant dense<1.234000e+01> : tensor + %extracted_0 = tensor.extract %cst_0[] : tensor + %out_qubits_0 = quantum.custom "RY"(%extracted_0) %1 {adjoint} : !quantum.bit + + %cst_1 = stablehlo.constant dense<1.234000e+01> : tensor + %extracted_1 = tensor.extract %cst_1[] : tensor + %out_qubits_1 = quantum.custom "RY"(%extracted_1) %out_qubits_0 : !quantum.bit + + // CHECK-NOT: quantum.custom + // CHECK: return [[IN]] + return %out_qubits_1 : !quantum.bit +} From b60519ea8736b3957068d94274fbb45b79b6f573 Mon Sep 17 00:00:00 2001 From: Haochen Wang Date: Mon, 7 Oct 2024 15:24:26 -0400 Subject: [PATCH 14/45] test with different explicit params --- mlir/test/Quantum/ChainedSelfInverseTest.mlir | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/mlir/test/Quantum/ChainedSelfInverseTest.mlir b/mlir/test/Quantum/ChainedSelfInverseTest.mlir index b8af8ee27d..07f441b021 100644 --- a/mlir/test/Quantum/ChainedSelfInverseTest.mlir +++ b/mlir/test/Quantum/ChainedSelfInverseTest.mlir @@ -337,3 +337,27 @@ func.func @test_chained_self_inverse() -> !quantum.bit { // CHECK: return [[IN]] return %out_qubits_1 : !quantum.bit } + + +// ----- + + +// test with unmatched explicit rotation angles +// CHECK-LABEL: test_chained_self_inverse +func.func @test_chained_self_inverse() -> !quantum.bit { + %0 = quantum.alloc( 1) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + + %cst_0 = stablehlo.constant dense<1.234000e+01> : tensor + %extracted_0 = tensor.extract %cst_0[] : tensor + %out_qubits_0 = quantum.custom "RY"(%extracted_0) %1 {adjoint} : !quantum.bit + + %cst_1 = stablehlo.constant dense<5.678000e+01> : tensor + %extracted_1 = tensor.extract %cst_1[] : tensor + %out_qubits_1 = quantum.custom "RY"(%extracted_1) %out_qubits_0 : !quantum.bit + + return %out_qubits_1 : !quantum.bit +} + +// CHECK: quantum.custom "RY"{{.+}}{adjoint} +// CHECK: quantum.custom "RY" From 84c758cbad5646732dbdd2c1572f0330b683e59b Mon Sep 17 00:00:00 2001 From: Romain Moyard Date: Mon, 7 Oct 2024 15:54:43 -0400 Subject: [PATCH 15/45] cano test --- mlir/test/Quantum/MergeRotationsTest.mlir | 58 +++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/mlir/test/Quantum/MergeRotationsTest.mlir b/mlir/test/Quantum/MergeRotationsTest.mlir index 7b8c214356..40c21f3f6c 100644 --- a/mlir/test/Quantum/MergeRotationsTest.mlir +++ b/mlir/test/Quantum/MergeRotationsTest.mlir @@ -30,6 +30,22 @@ func.func @test_merge_rotations(%arg0: f64, %arg1: f64) -> !quantum.bit { // ----- +func.func @test_merge_rotations(%arg0: f64, %arg1: f64) -> !quantum.bit { + %0 = quantum.alloc( 1) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + // CHECK: [[reg:%.+]] = quantum.alloc( 1) : !quantum.reg + // CHECK: [[qubit:%.+]] = quantum.extract [[reg]][ 0] : !quantum.reg -> !quantum.bit + // CHECK: [[sum:%.+]] = arith.addf %arg0, %arg1 : f64 + // CHECK: [[ret:%.+]] = quantum.custom "PhaseShift"([[sum]]) [[qubit]] : !quantum.bit + // CHECK-NOT: quantum.custom "PhaseShift" + %2 = quantum.custom "PhaseShift"(%arg0) %1 : !quantum.bit + %3 = quantum.custom "PhaseShift"(%arg1) %2 : !quantum.bit + // CHECK: return [[ret]] + return %3 : !quantum.bit +} + +// ----- + func.func @test_merge_rotations(%arg0: f64, %arg1: f64, %arg2: f64) -> !quantum.bit { %0 = quantum.alloc( 1) : !quantum.reg %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit @@ -87,4 +103,46 @@ func.func @test_merge_rotations(%arg0: f64, %arg1: f64, %arg2: f64) -> (!quantum %5:2 = quantum.custom "CRY"(%arg2) %4#1, %4#0 : !quantum.bit, !quantum.bit // CHECK: return [[ret]]#0, [[ret]]#1 return %5#0, %5#1 : !quantum.bit, !quantum.bit +} + +// ----- + +func.func @test_merge_rotations(%arg0: f64, %arg1: f64, %arg2: f64) -> (!quantum.bit, !quantum.bit) { + %0 = quantum.alloc( 2) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + %2 = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit + + // CHECK: [[reg:%.+]] = quantum.alloc( 2) : !quantum.reg + // CHECK: [[qubit1:%.+]] = quantum.extract [[reg]][ 0] : !quantum.reg -> !quantum.bit + // CHECK: [[qubit2:%.+]] = quantum.extract [[reg]][ 1] : !quantum.reg -> !quantum.bit + // CHECK: [[sum1:%.+]] = arith.addf %arg1, %arg2 : f64 + // CHECK: [[sum2:%.+]] = arith.addf %arg0, [[sum1]] : f64 + // CHECK: [[ret:%.+]]:2 = quantum.custom "ControlledPhaseShift"([[sum2]]) [[qubit1]], [[qubit2]] : !quantum.bit, !quantum.bit + // CHECK-NOT: quantum.custom "CRX" + %3:2 = quantum.custom "ControlledPhaseShift"(%arg0) %1, %2: !quantum.bit, !quantum.bit + %4:2 = quantum.custom "ControlledPhaseShift"(%arg1) %3#0, %3#1 : !quantum.bit, !quantum.bit + %5:2 = quantum.custom "ControlledPhaseShift"(%arg2) %4#0, %4#1 : !quantum.bit, !quantum.bit + // CHECK: return [[ret]]#0, [[ret]]#1 + return %5#0, %5#1 : !quantum.bit, !quantum.bit +} + +// ----- + +func.func @test_merge_rotations(%arg0: f64, %arg1: f64, %arg2: f64) -> (!quantum.bit, !quantum.bit) { + %0 = quantum.alloc( 2) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + %2 = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit + + // CHECK: [[reg:%.+]] = quantum.alloc( 2) : !quantum.reg + // CHECK: [[qubit1:%.+]] = quantum.extract [[reg]][ 0] : !quantum.reg -> !quantum.bit + // CHECK: [[qubit2:%.+]] = quantum.extract [[reg]][ 1] : !quantum.reg -> !quantum.bit + // CHECK: [[sum:%.+]] = arith.addf %arg0, %arg1 : f64 + // CHECK: [[qubits3:%.+]]:2 = quantum.custom "ControlledPhaseShift"([[sum]]) [[qubit1]], [[qubit2]] : !quantum.bit, !quantum.bit + // CHECK: [[ret:%.+]]:2 = quantum.custom "ControlledPhaseShift"(%arg2) [[qubits3]]#1, [[qubits3]]#0 : !quantum.bit, !quantum.bit + // CHECK-NOT: quantum.custom "ControlledPhaseShift" + %3:2 = quantum.custom "ControlledPhaseShift"(%arg0) %1, %2: !quantum.bit, !quantum.bit + %4:2 = quantum.custom "ControlledPhaseShift"(%arg1) %3#0, %3#1 : !quantum.bit, !quantum.bit + %5:2 = quantum.custom "ControlledPhaseShift"(%arg2) %4#1, %4#0 : !quantum.bit, !quantum.bit + // CHECK: return [[ret]]#0, [[ret]]#1 + return %5#0, %5#1 : !quantum.bit, !quantum.bit } \ No newline at end of file From e3447efa8c100389d1474d6aa35bf7d77444da27 Mon Sep 17 00:00:00 2001 From: Romain Moyard Date: Mon, 7 Oct 2024 16:29:14 -0400 Subject: [PATCH 16/45] Update --- .../Transforms/MergeRotationsPatterns.cpp | 46 +++++++++-------- mlir/test/Quantum/MergeRotationsTest.mlir | 49 ++++++++++++++++++- 2 files changed, 74 insertions(+), 21 deletions(-) diff --git a/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp b/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp index 5a367057b4..673d8dd1c9 100644 --- a/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp +++ b/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp @@ -20,12 +20,13 @@ #include "llvm/ADT/StringSet.h" #include "llvm/Support/Debug.h" #include "llvm/Support/Errc.h" + using llvm::dbgs; using namespace mlir; -using namespace catalyst; using namespace catalyst::quantum; -static const mlir::StringSet<> rotationsSet = {"RX", "RY", "RZ", "CRX", "CRY", "CRZ"}; +static const mlir::StringSet<> rotationsSet = {"RX", "RY", "RZ", "PhaseShift", "Rot", + "CRX", "CRY", "CRZ", "ControlledPhaseShift", "CRot"}; namespace { @@ -37,10 +38,12 @@ struct MergeRotationsRewritePattern : public mlir::OpRewritePattern { LLVM_DEBUG(dbgs() << "Simplifying the following operation:\n" << op << "\n"); auto loc = op.getLoc(); StringRef OpGateName = op.getGateName(); + if (!rotationsSet.contains(OpGateName)) return failure(); ValueRange InQubits = op.getInQubits(); auto parentOp = dyn_cast_or_null(InQubits[0].getDefiningOp()); + if (!parentOp || parentOp.getGateName() != OpGateName) return failure(); @@ -51,25 +54,28 @@ struct MergeRotationsRewritePattern : public mlir::OpRewritePattern { return failure(); } + TypeRange OutQubitsTypes = op.getOutQubits().getTypes(); + TypeRange OutQubitsCtrlTypes = op.getOutCtrlQubits().getTypes(); + ValueRange parentInQubits = parentOp.getInQubits(); + ValueRange parentInCtrlQubits = parentOp.getInCtrlQubits(); + ValueRange parentInCtrlValues = parentOp.getInCtrlValues(); + + // One param rot case + auto parentParams = parentOp.getParams(); + auto params = op.getParams(); + std::vector sumParams; + for (auto [param, parentParam] : llvm::zip(params, parentParams)) { + Value sumParam = rewriter.create(loc, parentParam, param).getResult(); + sumParams.push_back(sumParam); + }; + auto mergeOp = rewriter.create(loc, OutQubitsTypes, OutQubitsCtrlTypes, sumParams, + parentInQubits, OpGateName, nullptr, + parentInCtrlQubits, parentInCtrlValues); + + op.replaceAllUsesWith(mergeOp); + op.erase(); + parentOp.erase(); - static const mlir::StringSet<> rotationsSetCase1 = {"RX", "RY", "RZ", "CRX", "CRY", "CRZ"}; - - if (rotationsSetCase1.find(OpGateName) != rotationsSetCase1.end()) { - TypeRange OutQubitsTypes = op.getOutQubits().getTypes(); - TypeRange OutQubitsCtrlTypes = op.getOutCtrlQubits().getTypes(); - auto parentParams = parentOp.getParams().front(); - auto params = op.getParams().front(); - Value sumParams = rewriter.create(loc, parentParams, params).getResult(); - ValueRange parentInQubits = parentOp.getInQubits(); - ValueRange parentInCtrlQubits = parentOp.getInCtrlQubits(); - ValueRange parentInCtrlValues = parentOp.getInCtrlValues(); - auto mergeOp = rewriter.create(loc, OutQubitsTypes, OutQubitsCtrlTypes, - sumParams, parentInQubits, OpGateName, nullptr, - parentInCtrlQubits, parentInCtrlValues); - op.replaceAllUsesWith(mergeOp); - op.erase(); - parentOp.erase(); - } return success(); } }; diff --git a/mlir/test/Quantum/MergeRotationsTest.mlir b/mlir/test/Quantum/MergeRotationsTest.mlir index 40c21f3f6c..ca5d1b1344 100644 --- a/mlir/test/Quantum/MergeRotationsTest.mlir +++ b/mlir/test/Quantum/MergeRotationsTest.mlir @@ -145,4 +145,51 @@ func.func @test_merge_rotations(%arg0: f64, %arg1: f64, %arg2: f64) -> (!quantum %5:2 = quantum.custom "ControlledPhaseShift"(%arg2) %4#1, %4#0 : !quantum.bit, !quantum.bit // CHECK: return [[ret]]#0, [[ret]]#1 return %5#0, %5#1 : !quantum.bit, !quantum.bit -} \ No newline at end of file +} + +// ----- + +func.func @test_merge_rotations(%arg0: f64, %arg1: f64, %arg2: f64) -> !quantum.bit { + // CHECK: [[reg:%.+]] = quantum.alloc( 1) : !quantum.reg + // CHECK: [[qubit:%.+]] = quantum.extract [[reg]][ 0] : !quantum.reg -> !quantum.bit + %0 = quantum.alloc( 1) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + // CHECK: [[angle00:%.+]] = arith.addf %arg1, %arg2 : f64 + // CHECK: [[angle10:%.+]] = arith.addf %arg2, %arg0 : f64 + // CHECK: [[angle20:%.+]] = arith.addf %arg0, %arg1 : f64 + // CHECK: [[angle01:%.+]] = arith.addf %arg0, [[angle00]] : f64 + // CHECK: [[angle11:%.+]] = arith.addf %arg1, [[angle10]] : f64 + // CHECK: [[angle21:%.+]] = arith.addf %arg2, [[angle20]] : f64 + // CHECK: [[ret:%.+]] = quantum.custom "Rot"([[angle01]], [[angle11]], [[angle21]]) [[qubit]] : !quantum.bit + // CHECK-NOT: quantum.custom "Rot" + %2 = quantum.custom "Rot"(%arg0, %arg1, %arg2) %1 : !quantum.bit + %3 = quantum.custom "Rot"(%arg1, %arg2, %arg0) %2 : !quantum.bit + %4 = quantum.custom "Rot"(%arg2, %arg0, %arg1) %3 : !quantum.bit + // CHECK: return [[ret]] + return %4 : !quantum.bit +} + + +// ----- + +func.func @test_merge_rotations(%arg0: f64, %arg1: f64, %arg2: f64) -> (!quantum.bit, !quantum.bit) { + // CHECK: [[reg:%.+]] = quantum.alloc( 2) : !quantum.reg + // CHECK: [[qubit1:%.+]] = quantum.extract [[reg]][ 0] : !quantum.reg -> !quantum.bit + // CHECK: [[qubit2:%.+]] = quantum.extract [[reg]][ 1] : !quantum.reg -> !quantum.bit + %0 = quantum.alloc( 2) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + %2 = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit + // CHECK: [[angle00:%.+]] = arith.addf %arg1, %arg2 : f64 + // CHECK: [[angle10:%.+]] = arith.addf %arg2, %arg0 : f64 + // CHECK: [[angle20:%.+]] = arith.addf %arg0, %arg1 : f64 + // CHECK: [[angle01:%.+]] = arith.addf %arg0, [[angle00]] : f64 + // CHECK: [[angle11:%.+]] = arith.addf %arg1, [[angle10]] : f64 + // CHECK: [[angle21:%.+]] = arith.addf %arg2, [[angle20]] : f64 + // CHECK: [[ret:%.+]]:2 = quantum.custom "CRot"([[angle01]], [[angle11]], [[angle21]]) [[qubit1]], [[qubit2]] : !quantum.bit + // CHECK-NOT: quantum.custom "CRot" + %3:2 = quantum.custom "CRot"(%arg0, %arg1, %arg2) %1, %2 : !quantum.bit, !quantum.bit + %4:2 = quantum.custom "CRot"(%arg1, %arg2, %arg0) %3#0, %3#1 : !quantum.bit, !quantum.bit + %5:2 = quantum.custom "CRot"(%arg2, %arg0, %arg1) %4#0, %4#1 : !quantum.bit, !quantum.bit + // CHECK: return [[ret]]#0, [[ret]]#1 + return %5#0, %5#1 : !quantum.bit, !quantum.bit +} From a43d9ac51f94f7163d10af2c0a9bbac53faf2b34 Mon Sep 17 00:00:00 2001 From: Haochen Wang Date: Mon, 7 Oct 2024 16:54:55 -0400 Subject: [PATCH 17/45] changelog --- doc/releases/changelog-dev.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index e4168d3172..0594998c63 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -208,6 +208,8 @@ * Samples on lightning.qubit/kokkos can now be seeded with `qjit(seed=...)`. [(#1164)](https://github.com/PennyLaneAI/catalyst/pull/1164) +* The compiler pass `-remove-chained-self-inverse` can now also cancel adjoints of arbitrary unitaries (on top of just the named Hermitian gates). + [(#1186)](https://github.com/PennyLaneAI/catalyst/pull/1186)

Breaking changes

From 05f5a5bfa8cb530963391ad826b875efe13eb464 Mon Sep 17 00:00:00 2001 From: Romain Moyard Date: Mon, 7 Oct 2024 17:19:55 -0400 Subject: [PATCH 18/45] Initial draft multiRZ --- .../Transforms/MergeRotationsPatterns.cpp | 74 ++++++++++++++++--- 1 file changed, 62 insertions(+), 12 deletions(-) diff --git a/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp b/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp index 673d8dd1c9..f37be8cc61 100644 --- a/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp +++ b/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp @@ -37,25 +37,25 @@ struct MergeRotationsRewritePattern : public mlir::OpRewritePattern { { LLVM_DEBUG(dbgs() << "Simplifying the following operation:\n" << op << "\n"); auto loc = op.getLoc(); - StringRef OpGateName = op.getGateName(); + StringRef opGateName = op.getGateName(); - if (!rotationsSet.contains(OpGateName)) + if (!rotationsSet.contains(opGateName)) return failure(); - ValueRange InQubits = op.getInQubits(); - auto parentOp = dyn_cast_or_null(InQubits[0].getDefiningOp()); + ValueRange inQubits = op.getInQubits(); + auto parentOp = dyn_cast_or_null(inQubits[0].getDefiningOp()); - if (!parentOp || parentOp.getGateName() != OpGateName) + if (!parentOp || parentOp.getGateName() != opGateName) return failure(); - ValueRange ParentOutQubits = parentOp.getOutQubits(); + ValueRange parentOutQubits = parentOp.getOutQubits(); // Check if the input qubits to the current operation match the output qubits of the parent. - for (const auto &[Idx, Qubit] : llvm::enumerate(InQubits)) { - if (Qubit.getDefiningOp() != parentOp || Qubit != ParentOutQubits[Idx]) + for (const auto &[Idx, Qubit] : llvm::enumerate(inQubits)) { + if (Qubit.getDefiningOp() != parentOp || Qubit != parentOutQubits[Idx]) return failure(); } - TypeRange OutQubitsTypes = op.getOutQubits().getTypes(); - TypeRange OutQubitsCtrlTypes = op.getOutCtrlQubits().getTypes(); + TypeRange outQubitsTypes = op.getOutQubits().getTypes(); + TypeRange outQubitsCtrlTypes = op.getOutCtrlQubits().getTypes(); ValueRange parentInQubits = parentOp.getInQubits(); ValueRange parentInCtrlQubits = parentOp.getInCtrlQubits(); ValueRange parentInCtrlValues = parentOp.getInCtrlValues(); @@ -68,8 +68,8 @@ struct MergeRotationsRewritePattern : public mlir::OpRewritePattern { Value sumParam = rewriter.create(loc, parentParam, param).getResult(); sumParams.push_back(sumParam); }; - auto mergeOp = rewriter.create(loc, OutQubitsTypes, OutQubitsCtrlTypes, sumParams, - parentInQubits, OpGateName, nullptr, + auto mergeOp = rewriter.create(loc, outQubitsTypes, outQubitsCtrlTypes, sumParams, + parentInQubits, opGateName, nullptr, parentInCtrlQubits, parentInCtrlValues); op.replaceAllUsesWith(mergeOp); @@ -80,6 +80,55 @@ struct MergeRotationsRewritePattern : public mlir::OpRewritePattern { } }; +struct MergeMultiRZRewritePattern : public mlir::OpRewritePattern { + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite(MultiRZOp op, + mlir::PatternRewriter &rewriter) const override + { + LLVM_DEBUG(dbgs() << "Simplifying the following operation:\n" << op << "\n"); + auto loc = op.getLoc(); + ValueRange InQubits = op.getInQubits(); + + // Check parent op + auto parentOp = dyn_cast_or_null(InQubits[0].getDefiningOp()); + + if (!parentOp) + return failure(); + + // Check the target qubit + ValueRange parentOutQubits = parentOp.getOutQubits(); + for (const auto &[Idx, Qubit] : llvm::enumerate(InQubits)) { + if (Qubit.getDefiningOp() != parentOp || Qubit != parentOutQubits[Idx]) + return failure(); + } + + // Check the control qubits + ValueRange inCtrlQubits = op.getInCtrlQubits(); + ValueRange parentOutCtrlQubits = parentOp.getOutCtrlQubits(); + for (const auto &[Idx, Qubit] : llvm::enumerate(InQubits)) { + if (Qubit.getDefiningOp() != parentOp || Qubit != parentOutCtrlQubits[Idx]) + return failure(); + } + // Check the control values + + // ... + + // Sum the angles control values + + // Replace operation + TypeRange outQubitsTypes = op.getOutQubits().getTypes(); + TypeRange outQubitsCtrlTypes = op.getOutCtrlQubits().getTypes(); + ValueRange parentInQubits = parentOp.getInQubits(); + ValueRange parentInCtrlQubits = parentOp.getInCtrlQubits(); + ValueRange parentInCtrlValues = parentOp.getInCtrlValues(); + // op.replaceAllUsesWith(mergeOp); + // op.erase(); + // parentOp.erase(); + + return success(); + } +}; } // namespace namespace catalyst { @@ -88,6 +137,7 @@ namespace quantum { void populateMergeRotationsPatterns(RewritePatternSet &patterns) { patterns.add(patterns.getContext(), 1); + patterns.add(patterns.getContext(), 1); } } // namespace quantum From bc5faafb831b41242958cf5c8aab1303adbbb1a6 Mon Sep 17 00:00:00 2001 From: Romain Moyard Date: Mon, 7 Oct 2024 17:23:39 -0400 Subject: [PATCH 19/45] Typo --- mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp b/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp index f37be8cc61..ecdf776593 100644 --- a/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp +++ b/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp @@ -60,7 +60,6 @@ struct MergeRotationsRewritePattern : public mlir::OpRewritePattern { ValueRange parentInCtrlQubits = parentOp.getInCtrlQubits(); ValueRange parentInCtrlValues = parentOp.getInCtrlValues(); - // One param rot case auto parentParams = parentOp.getParams(); auto params = op.getParams(); std::vector sumParams; From 485c6d6961a1e8073e15f32ea3b54d3d27ef0fb9 Mon Sep 17 00:00:00 2001 From: Haochen Wang Date: Tue, 8 Oct 2024 14:57:14 -0400 Subject: [PATCH 20/45] ctrl gates --- .../Transforms/ChainedSelfInversePatterns.cpp | 64 +++++++++++++++---- mlir/test/Quantum/ChainedSelfInverseTest.mlir | 31 +++++++++ 2 files changed, 82 insertions(+), 13 deletions(-) diff --git a/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp b/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp index 889c5f093e..81c9db7a7d 100644 --- a/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp +++ b/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp @@ -14,11 +14,13 @@ #define DEBUG_TYPE "chained-self-inverse" -#include "Quantum/IR/QuantumOps.h" -#include "Quantum/Transforms/Patterns.h" #include "llvm/ADT/StringSet.h" #include "llvm/Support/Debug.h" #include "llvm/Support/Errc.h" + +#include "Quantum/IR/QuantumOps.h" +#include "Quantum/Transforms/Patterns.h" + using llvm::dbgs; using namespace mlir; using namespace catalyst; @@ -70,8 +72,7 @@ struct ChainedUUadjOpRewritePattern : public mlir::OpRewritePattern { bool verifyParentGateType(OpType op, OpType parentOp) const { - // Verify that the parent gate is of the same type, - // and parent's results and current gate's inputs are in the same order + // Verify that the parent gate is of the same type. // If OpType is quantum.custom, also verify that parent gate has the // same gate name. @@ -87,10 +88,39 @@ struct ChainedUUadjOpRewritePattern : public mlir::OpRewritePattern { } } - ValueRange InQubits = op.getInQubits(); - ValueRange ParentOutQubits = parentOp.getOutQubits(); - for (const auto &[Idx, Qubit] : llvm::enumerate(InQubits)) { - if (Qubit.getDefiningOp() != parentOp || Qubit != ParentOutQubits[Idx]) { + return true; + } + + bool verifyAllInQubits(OpType op, OpType parentOp) const + { + // Verify that parent's results and current gate's inputs are in the same order + // If the gates are controlled, both gates' control wires and values + // must be the same. The control wires must be in the same order. + + ValueRange inNonCtrlQubits = op.getNonCtrlQubitOperands(); + ValueRange inCtrlQubits = op.getCtrlQubitOperands(); + ValueRange parentOutNonCtrlQubits = parentOp.getNonCtrlQubitResults(); + ValueRange parentOutCtrlQubits = parentOp.getCtrlQubitResults(); + + if ((inNonCtrlQubits.size() != parentOutNonCtrlQubits.size()) || + (inCtrlQubits.size() != parentOutCtrlQubits.size())) { + return false; + } + + for (const auto &[idx, qubit] : llvm::enumerate(inNonCtrlQubits)) { + if (qubit.getDefiningOp() != parentOp || qubit != parentOutNonCtrlQubits[idx]) { + return false; + } + } + + ValueRange opCtrlValues = op.getCtrlValueOperands(); + ValueRange parentCtrlValues = parentOp.getCtrlValueOperands(); + if (opCtrlValues.size() != parentCtrlValues.size()) { + return false; + } + + for (const auto &[idx, v] : llvm::enumerate(opCtrlValues)) { + if (v != parentCtrlValues[idx]) { return false; } } @@ -141,8 +171,6 @@ struct ChainedUUadjOpRewritePattern : public mlir::OpRewritePattern { { LLVM_DEBUG(dbgs() << "Simplifying the following operation:\n" << op << "\n"); - // llvm::errs() << "visiting " << op << "\n"; - ValueRange InQubits = op.getInQubits(); auto parentOp = dyn_cast_or_null(InQubits[0].getDefiningOp()); @@ -150,6 +178,10 @@ struct ChainedUUadjOpRewritePattern : public mlir::OpRewritePattern { return failure(); } + if (!verifyAllInQubits(op, parentOp)) { + return failure(); + } + if (!verifyParentGateParams(op, parentOp)) { return failure(); } @@ -158,9 +190,15 @@ struct ChainedUUadjOpRewritePattern : public mlir::OpRewritePattern { return failure(); } - // llvm::errs() << "matched!\n"; - ValueRange simplifiedVal = parentOp.getInQubits(); - rewriter.replaceOp(op, simplifiedVal); + // Replace uses + ValueRange originalNonCtrlQubits = parentOp.getNonCtrlQubitOperands(); + ValueRange originalCtrlQubits = parentOp.getCtrlQubitOperands(); + for (const auto &[idx, nonCtrlQubitResult] : llvm::enumerate(op.getNonCtrlQubitResults())) { + nonCtrlQubitResult.replaceAllUsesWith(originalNonCtrlQubits[idx]); + } + for (const auto &[idx, ctrlQubitResult] : llvm::enumerate(op.getCtrlQubitResults())) { + ctrlQubitResult.replaceAllUsesWith(originalCtrlQubits[idx]); + } return success(); } }; diff --git a/mlir/test/Quantum/ChainedSelfInverseTest.mlir b/mlir/test/Quantum/ChainedSelfInverseTest.mlir index 07f441b021..f8f5fa5c5d 100644 --- a/mlir/test/Quantum/ChainedSelfInverseTest.mlir +++ b/mlir/test/Quantum/ChainedSelfInverseTest.mlir @@ -361,3 +361,34 @@ func.func @test_chained_self_inverse() -> !quantum.bit { // CHECK: quantum.custom "RY"{{.+}}{adjoint} // CHECK: quantum.custom "RY" + + +// ----- + + +// test with matched control wires +// CHECK-LABEL: test_chained_self_inverse +func.func @test_chained_self_inverse() -> (!quantum.bit, !quantum.bit, !quantum.bit) { + %true = llvm.mlir.constant (1 : i1) :i1 + %false = llvm.mlir.constant (0 : i1) :i1 + %cst = llvm.mlir.constant (6.000000e-01 : f64) : f64 + %cst_0 = llvm.mlir.constant (9.000000e-01 : f64) : f64 + %cst_1 = llvm.mlir.constant (3.000000e-01 : f64) : f64 + + // CHECK: quantum.alloc + // CHECK: [[IN0:%.+]] = quantum.extract {{.+}}[ 0] + // CHECK: [[IN1:%.+]] = quantum.extract {{.+}}[ 1] + // CHECK: [[IN2:%.+]] = quantum.extract {{.+}}[ 2] + %reg = quantum.alloc( 3) : !quantum.reg + %0 = quantum.extract %reg[ 0] : !quantum.reg -> !quantum.bit + %1 = quantum.extract %reg[ 1] : !quantum.reg -> !quantum.bit + %2 = quantum.extract %reg[ 2] : !quantum.reg -> !quantum.bit + + %out_qubits:2, %out_ctrl_qubits = quantum.custom "Rot"(%cst, %cst_0, %cst_1) %0, %1 ctrls(%2) ctrlvals(%true) : !quantum.bit, !quantum.bit ctrls !quantum.bit + %out_qubits_1:2, %out_ctrl_qubits_1 = quantum.custom "Rot"(%cst, %cst_0, %cst_1) %out_qubits#0, %out_qubits#1 {adjoint} ctrls(%out_ctrl_qubits) ctrlvals(%true) : !quantum.bit, !quantum.bit ctrls !quantum.bit + + + // CHECK-NOT: quantum.custom + // CHECK: return [[IN0]], [[IN1]], [[IN2]] + return %out_qubits_1#0, %out_qubits_1#1, %out_ctrl_qubits_1 : !quantum.bit, !quantum.bit, !quantum.bit +} From 90ddcf1f190a0a54d3f2db23451e61b098872f2e Mon Sep 17 00:00:00 2001 From: Haochen Wang Date: Tue, 8 Oct 2024 15:15:27 -0400 Subject: [PATCH 21/45] remove template type in parent getter (a value will have just one definition) --- mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp b/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp index 81c9db7a7d..e245aced06 100644 --- a/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp +++ b/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp @@ -108,7 +108,7 @@ struct ChainedUUadjOpRewritePattern : public mlir::OpRewritePattern { } for (const auto &[idx, qubit] : llvm::enumerate(inNonCtrlQubits)) { - if (qubit.getDefiningOp() != parentOp || qubit != parentOutNonCtrlQubits[idx]) { + if (qubit.getDefiningOp() != parentOp || qubit != parentOutNonCtrlQubits[idx]) { return false; } } From b9acfe69e7a5a9991f029b6df8d51de073aa142b Mon Sep 17 00:00:00 2001 From: Haochen Wang Date: Tue, 8 Oct 2024 17:25:12 -0400 Subject: [PATCH 22/45] factor out a parent gate verifier analysis, so it can be reused with merge_rotations, etc. --- .../Transforms/ChainedSelfInversePatterns.cpp | 81 ++--------- .../Transforms/VerifyParentGateAnalysis.hpp | 133 ++++++++++++++++++ 2 files changed, 141 insertions(+), 73 deletions(-) create mode 100644 mlir/lib/Quantum/Transforms/VerifyParentGateAnalysis.hpp diff --git a/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp b/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp index e245aced06..06bb16fda7 100644 --- a/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp +++ b/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp @@ -14,6 +14,8 @@ #define DEBUG_TYPE "chained-self-inverse" +#include "VerifyParentGateAnalysis.hpp" + #include "llvm/ADT/StringSet.h" #include "llvm/Support/Debug.h" #include "llvm/Support/Errc.h" @@ -70,68 +72,9 @@ template struct ChainedUUadjOpRewritePattern : public mlir::OpRewritePattern { using mlir::OpRewritePattern::OpRewritePattern; - bool verifyParentGateType(OpType op, OpType parentOp) const - { - // Verify that the parent gate is of the same type. - // If OpType is quantum.custom, also verify that parent gate has the - // same gate name. - - if (!parentOp || !isa(parentOp)) { - return false; - } - - if (isa(op)) { - StringRef opGateName = cast(op).getGateName(); - StringRef parentGateName = cast(parentOp).getGateName(); - if (opGateName != parentGateName) { - return false; - } - } - - return true; - } - - bool verifyAllInQubits(OpType op, OpType parentOp) const - { - // Verify that parent's results and current gate's inputs are in the same order - // If the gates are controlled, both gates' control wires and values - // must be the same. The control wires must be in the same order. - - ValueRange inNonCtrlQubits = op.getNonCtrlQubitOperands(); - ValueRange inCtrlQubits = op.getCtrlQubitOperands(); - ValueRange parentOutNonCtrlQubits = parentOp.getNonCtrlQubitResults(); - ValueRange parentOutCtrlQubits = parentOp.getCtrlQubitResults(); - - if ((inNonCtrlQubits.size() != parentOutNonCtrlQubits.size()) || - (inCtrlQubits.size() != parentOutCtrlQubits.size())) { - return false; - } - - for (const auto &[idx, qubit] : llvm::enumerate(inNonCtrlQubits)) { - if (qubit.getDefiningOp() != parentOp || qubit != parentOutNonCtrlQubits[idx]) { - return false; - } - } - - ValueRange opCtrlValues = op.getCtrlValueOperands(); - ValueRange parentCtrlValues = parentOp.getCtrlValueOperands(); - if (opCtrlValues.size() != parentCtrlValues.size()) { - return false; - } - - for (const auto &[idx, v] : llvm::enumerate(opCtrlValues)) { - if (v != parentCtrlValues[idx]) { - return false; - } - } - - return true; - } - bool verifyParentGateParams(OpType op, OpType parentOp) const { // Verify that the parent gate has the same parameters - ValueRange opParams = op.getAllParams(); ValueRange parentOpParams = parentOp.getAllParams(); @@ -159,28 +102,20 @@ struct ChainedUUadjOpRewritePattern : public mlir::OpRewritePattern { /// Remove generic neighbouring gate pairs of the form /// --- gate --- gate{adjoint} --- /// Conditions: - /// 1. Both gates must be of the same type, i.e. a quantum.custom can - /// only be cancelled with a quantum.custom, not a quantum.unitary - /// 2. The results of the parent gate must map one-to-one, in order, - /// to the operands of the second gate - /// 3. If there are parameters, both gate must have the same parameters. + /// 1. Parent gate verification must pass. See VerifyParentGateAnalysis.hpp. + /// 2. If there are parameters, both gate must have the same parameters. /// [This pattern assumes the IR is already processed by CSE] - /// 4. If the gates are controlled, both gates' control wires and values - /// must be the same. The control wires must be in the same order mlir::LogicalResult matchAndRewrite(OpType op, mlir::PatternRewriter &rewriter) const override { LLVM_DEBUG(dbgs() << "Simplifying the following operation:\n" << op << "\n"); - ValueRange InQubits = op.getInQubits(); - auto parentOp = dyn_cast_or_null(InQubits[0].getDefiningOp()); - - if (!verifyParentGateType(op, parentOp)) { + VerifyParentGateAnalysis vpga(op); + if (!vpga.getVerifierResult()) { return failure(); } - if (!verifyAllInQubits(op, parentOp)) { - return failure(); - } + ValueRange InQubits = op.getInQubits(); + auto parentOp = dyn_cast_or_null(InQubits[0].getDefiningOp()); if (!verifyParentGateParams(op, parentOp)) { return failure(); diff --git a/mlir/lib/Quantum/Transforms/VerifyParentGateAnalysis.hpp b/mlir/lib/Quantum/Transforms/VerifyParentGateAnalysis.hpp new file mode 100644 index 0000000000..a770f1a253 --- /dev/null +++ b/mlir/lib/Quantum/Transforms/VerifyParentGateAnalysis.hpp @@ -0,0 +1,133 @@ +// Copyright 2024 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This analysis checks if a gate operation and its parent gate operation +// are correctly matched for the purposes of peephole optimizations like +// merge rotations and cancel inverses. +// Gates passing this analysis are considered valid candidates for merge +// rotation and cancel inverses. + +// Specifically, we check the following conditions: +// 1. Both gates must be of the same type, i.e. a quantum.custom can +// only be cancelled with a quantum.custom, not a quantum.unitary +// 2. The results of the parent gate must map one-to-one, in order, +// to the operands of the second gate +// For example, the pair +// %0:2 = quantum.custom "CNOT"() %.., %.. +// %1:2 = quantum.custom "CNOT"() %0#0, %0#1 +// is considered to be a successful match, but the pair +// %0:2 = quantum.custom "CNOT"() %.., %.. +// %1:2 = quantum.custom "CNOT"() %0#1, %0#0 +// is not. +// 3. If the gates are controlled, both gates' control wires and values +// must be the same. The control wires must be in the same order + +#pragma once + +#include "mlir/IR/BuiltinOps.h" +#include "llvm/Support/Debug.h" + +#include "Catalyst/IR/CatalystDialect.h" +#include "Quantum/IR/QuantumOps.h" + +using namespace llvm; +using namespace mlir; +using namespace catalyst; + +namespace catalyst { + +template class VerifyParentGateAnalysis { + public: + VerifyParentGateAnalysis(OpType gate) + { + ValueRange inQubits = gate.getInQubits(); + auto parentGate = dyn_cast_or_null(inQubits[0].getDefiningOp()); + + if (!verifyParentGateType(gate, parentGate)) { + verified = false; + return; + } + + if (!verifyAllInQubits(gate, parentGate)) { + verified = false; + return; + } + } + + bool getVerifierResult() { return verified; } + + private: + bool verified = true; + + bool verifyParentGateType(OpType op, OpType parentOp) const + { + // Verify that the parent gate is of the same type. + // If OpType is quantum.custom, also verify that parent gate has the + // same gate name. + + if (!parentOp || !isa(parentOp)) { + return false; + } + + if (isa(op)) { + StringRef opGateName = cast(op).getGateName(); + StringRef parentGateName = cast(parentOp).getGateName(); + if (opGateName != parentGateName) { + return false; + } + } + + return true; + } + + bool verifyAllInQubits(OpType op, OpType parentOp) const + { + // Verify that parent's results and current gate's inputs are in the same order + // If the gates are controlled, both gates' control wires and values + // must be the same. The control wires must be in the same order. + + ValueRange inNonCtrlQubits = op.getNonCtrlQubitOperands(); + ValueRange inCtrlQubits = op.getCtrlQubitOperands(); + ValueRange parentOutNonCtrlQubits = parentOp.getNonCtrlQubitResults(); + ValueRange parentOutCtrlQubits = parentOp.getCtrlQubitResults(); + + if ((inNonCtrlQubits.size() != parentOutNonCtrlQubits.size()) || + (inCtrlQubits.size() != parentOutCtrlQubits.size())) { + return false; + } + + for (const auto &[idx, qubit] : llvm::enumerate(inNonCtrlQubits)) { + if (qubit.getDefiningOp() != parentOp || qubit != parentOutNonCtrlQubits[idx]) { + return false; + } + } + + ValueRange opCtrlValues = op.getCtrlValueOperands(); + ValueRange parentCtrlValues = parentOp.getCtrlValueOperands(); + if (opCtrlValues.size() != parentCtrlValues.size()) { + return false; + } + + for (const auto &[idx, v] : llvm::enumerate(opCtrlValues)) { + // We assume CSE is already run before this analysis. + if (v != parentCtrlValues[idx]) { + return false; + } + } + + return true; + } +}; + +} // namespace catalyst From 13eb0944462385808dc99b352fca36abcd0f5d44 Mon Sep 17 00:00:00 2001 From: Haochen Wang Date: Wed, 9 Oct 2024 09:51:43 -0400 Subject: [PATCH 23/45] add all test cases for ctrl --- .../Transforms/VerifyParentGateAnalysis.hpp | 6 + mlir/test/Quantum/ChainedSelfInverseTest.mlir | 145 +++++++++++++++++- 2 files changed, 144 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Quantum/Transforms/VerifyParentGateAnalysis.hpp b/mlir/lib/Quantum/Transforms/VerifyParentGateAnalysis.hpp index a770f1a253..57009dd060 100644 --- a/mlir/lib/Quantum/Transforms/VerifyParentGateAnalysis.hpp +++ b/mlir/lib/Quantum/Transforms/VerifyParentGateAnalysis.hpp @@ -113,6 +113,12 @@ template class VerifyParentGateAnalysis { } } + for (const auto &[idx, qubit] : llvm::enumerate(inCtrlQubits)) { + if (qubit.getDefiningOp() != parentOp || qubit != parentOutCtrlQubits[idx]) { + return false; + } + } + ValueRange opCtrlValues = op.getCtrlValueOperands(); ValueRange parentCtrlValues = parentOp.getCtrlValueOperands(); if (opCtrlValues.size() != parentCtrlValues.size()) { diff --git a/mlir/test/Quantum/ChainedSelfInverseTest.mlir b/mlir/test/Quantum/ChainedSelfInverseTest.mlir index f8f5fa5c5d..4b082c0936 100644 --- a/mlir/test/Quantum/ChainedSelfInverseTest.mlir +++ b/mlir/test/Quantum/ChainedSelfInverseTest.mlir @@ -274,6 +274,7 @@ func.func @test_chained_self_inverse() -> !quantum.bit { // test quantum.unitary labeled with adjoint attribute + // CHECK-LABEL: test_chained_self_inverse func.func @test_chained_self_inverse(%arg0: tensor<2x2xf64>, %arg1: tensor) -> !quantum.bit { // CHECK: quantum.alloc @@ -296,6 +297,7 @@ func.func @test_chained_self_inverse(%arg0: tensor<2x2xf64>, %arg1: tensor) // test quantum.custom labeled with adjoint attribute + // CHECK-LABEL: test_chained_self_inverse func.func @test_chained_self_inverse(%arg0: tensor) -> !quantum.bit { // CHECK: quantum.alloc @@ -308,8 +310,34 @@ func.func @test_chained_self_inverse(%arg0: tensor) -> !quantum.bit { %extracted_1 = tensor.extract %arg0[] : tensor %out_qubits_1 = quantum.custom "RX"(%extracted_1) %out_qubits {adjoint} : !quantum.bit + + %out_qubits_2 = quantum.custom "RX"(%extracted_0) %out_qubits_1 {adjoint} : !quantum.bit + %out_qubits_3 = quantum.custom "RX"(%extracted_1) %out_qubits_2 : !quantum.bit + // CHECK-NOT: quantum.custom // CHECK: return [[IN]] + return %out_qubits_3 : !quantum.bit +} + + +// ----- + + +// test quantum.custom labeled both with adjoints + +// CHECK-LABEL: test_chained_self_inverse +func.func @test_chained_self_inverse(%arg0: tensor) -> !quantum.bit { + // CHECK: quantum.alloc + // CHECK: [[IN:%.+]] = quantum.extract + %0 = quantum.alloc( 1) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + + %extracted_0 = tensor.extract %arg0[] : tensor + %out_qubits = quantum.custom "RX"(%extracted_0) %1 {adjoint} : !quantum.bit + %extracted_1 = tensor.extract %arg0[] : tensor + %out_qubits_1 = quantum.custom "RX"(%extracted_1) %out_qubits {adjoint} : !quantum.bit + + // CHECK: quantum.custom return %out_qubits_1 : !quantum.bit } @@ -318,6 +346,7 @@ func.func @test_chained_self_inverse(%arg0: tensor) -> !quantum.bit { // test with explicit rotation angles + // CHECK-LABEL: test_chained_self_inverse func.func @test_chained_self_inverse() -> !quantum.bit { // CHECK: quantum.alloc @@ -343,6 +372,7 @@ func.func @test_chained_self_inverse() -> !quantum.bit { // test with unmatched explicit rotation angles + // CHECK-LABEL: test_chained_self_inverse func.func @test_chained_self_inverse() -> !quantum.bit { %0 = quantum.alloc( 1) : !quantum.reg @@ -367,8 +397,9 @@ func.func @test_chained_self_inverse() -> !quantum.bit { // test with matched control wires + // CHECK-LABEL: test_chained_self_inverse -func.func @test_chained_self_inverse() -> (!quantum.bit, !quantum.bit, !quantum.bit) { +func.func @test_chained_self_inverse() -> (!quantum.bit, !quantum.bit, !quantum.bit, !quantum.bit) { %true = llvm.mlir.constant (1 : i1) :i1 %false = llvm.mlir.constant (0 : i1) :i1 %cst = llvm.mlir.constant (6.000000e-01 : f64) : f64 @@ -379,16 +410,116 @@ func.func @test_chained_self_inverse() -> (!quantum.bit, !quantum.bit, !quantum. // CHECK: [[IN0:%.+]] = quantum.extract {{.+}}[ 0] // CHECK: [[IN1:%.+]] = quantum.extract {{.+}}[ 1] // CHECK: [[IN2:%.+]] = quantum.extract {{.+}}[ 2] - %reg = quantum.alloc( 3) : !quantum.reg + // CHECK: [[IN3:%.+]] = quantum.extract {{.+}}[ 3] + %reg = quantum.alloc( 4) : !quantum.reg %0 = quantum.extract %reg[ 0] : !quantum.reg -> !quantum.bit %1 = quantum.extract %reg[ 1] : !quantum.reg -> !quantum.bit %2 = quantum.extract %reg[ 2] : !quantum.reg -> !quantum.bit + %3 = quantum.extract %reg[ 3] : !quantum.reg -> !quantum.bit - %out_qubits:2, %out_ctrl_qubits = quantum.custom "Rot"(%cst, %cst_0, %cst_1) %0, %1 ctrls(%2) ctrlvals(%true) : !quantum.bit, !quantum.bit ctrls !quantum.bit - %out_qubits_1:2, %out_ctrl_qubits_1 = quantum.custom "Rot"(%cst, %cst_0, %cst_1) %out_qubits#0, %out_qubits#1 {adjoint} ctrls(%out_ctrl_qubits) ctrlvals(%true) : !quantum.bit, !quantum.bit ctrls !quantum.bit - + %out_qubits:2, %out_ctrl_qubits:2 = quantum.custom "Rot"(%cst, %cst_0, %cst_1) %0, %1 ctrls(%2, %3) ctrlvals(%true, %false) : !quantum.bit, !quantum.bit ctrls !quantum.bit, !quantum.bit + %out_qubits_1:2, %out_ctrl_qubits_1:2 = quantum.custom "Rot"(%cst, %cst_0, %cst_1) %out_qubits#0, %out_qubits#1 {adjoint} ctrls(%out_ctrl_qubits#0, %out_ctrl_qubits#1) ctrlvals(%true, %false) : !quantum.bit, !quantum.bit ctrls !quantum.bit, !quantum.bit // CHECK-NOT: quantum.custom - // CHECK: return [[IN0]], [[IN1]], [[IN2]] - return %out_qubits_1#0, %out_qubits_1#1, %out_ctrl_qubits_1 : !quantum.bit, !quantum.bit, !quantum.bit + // CHECK: return [[IN0]], [[IN1]], [[IN2]], [[IN3]] + return %out_qubits_1#0, %out_qubits_1#1, %out_ctrl_qubits_1#0, %out_ctrl_qubits_1#1 : !quantum.bit, !quantum.bit, !quantum.bit, !quantum.bit +} + + +// ----- + + +// test with unmatched operation wires + +// CHECK-LABEL: test_chained_self_inverse +func.func @test_chained_self_inverse() -> (!quantum.bit, !quantum.bit, !quantum.bit, !quantum.bit) { + %true = llvm.mlir.constant (1 : i1) :i1 + %false = llvm.mlir.constant (0 : i1) :i1 + %cst = llvm.mlir.constant (6.000000e-01 : f64) : f64 + %cst_0 = llvm.mlir.constant (9.000000e-01 : f64) : f64 + %cst_1 = llvm.mlir.constant (3.000000e-01 : f64) : f64 + + // CHECK: quantum.alloc + // CHECK: [[IN0:%.+]] = quantum.extract {{.+}}[ 0] + // CHECK: [[IN1:%.+]] = quantum.extract {{.+}}[ 1] + // CHECK: [[IN2:%.+]] = quantum.extract {{.+}}[ 2] + // CHECK: [[IN3:%.+]] = quantum.extract {{.+}}[ 3] + %reg = quantum.alloc( 4) : !quantum.reg + %0 = quantum.extract %reg[ 0] : !quantum.reg -> !quantum.bit + %1 = quantum.extract %reg[ 1] : !quantum.reg -> !quantum.bit + %2 = quantum.extract %reg[ 2] : !quantum.reg -> !quantum.bit + %3 = quantum.extract %reg[ 3] : !quantum.reg -> !quantum.bit + + // CHECK: quantum.custom + %out_qubits:2, %out_ctrl_qubits:2 = quantum.custom "Rot"(%cst, %cst_0, %cst_1) %0, %1 ctrls(%2, %3) ctrlvals(%true, %false) : !quantum.bit, !quantum.bit ctrls !quantum.bit, !quantum.bit + %out_qubits_1:2, %out_ctrl_qubits_1:2 = quantum.custom "Rot"(%cst, %cst_0, %cst_1) %out_qubits#1, %out_qubits#0 {adjoint} ctrls(%out_ctrl_qubits#0, %out_ctrl_qubits#1) ctrlvals(%true, %false) : !quantum.bit, !quantum.bit ctrls !quantum.bit, !quantum.bit + + + return %out_qubits_1#0, %out_qubits_1#1, %out_ctrl_qubits_1#0, %out_ctrl_qubits_1#1 : !quantum.bit, !quantum.bit, !quantum.bit, !quantum.bit +} + + +// ----- + + +// test with unmatched control wires + +// CHECK-LABEL: test_chained_self_inverse +func.func @test_chained_self_inverse() -> (!quantum.bit, !quantum.bit, !quantum.bit, !quantum.bit) { + %true = llvm.mlir.constant (1 : i1) :i1 + %false = llvm.mlir.constant (0 : i1) :i1 + %cst = llvm.mlir.constant (6.000000e-01 : f64) : f64 + %cst_0 = llvm.mlir.constant (9.000000e-01 : f64) : f64 + %cst_1 = llvm.mlir.constant (3.000000e-01 : f64) : f64 + + // CHECK: quantum.alloc + // CHECK: [[IN0:%.+]] = quantum.extract {{.+}}[ 0] + // CHECK: [[IN1:%.+]] = quantum.extract {{.+}}[ 1] + // CHECK: [[IN2:%.+]] = quantum.extract {{.+}}[ 2] + // CHECK: [[IN3:%.+]] = quantum.extract {{.+}}[ 3] + %reg = quantum.alloc( 4) : !quantum.reg + %0 = quantum.extract %reg[ 0] : !quantum.reg -> !quantum.bit + %1 = quantum.extract %reg[ 1] : !quantum.reg -> !quantum.bit + %2 = quantum.extract %reg[ 2] : !quantum.reg -> !quantum.bit + %3 = quantum.extract %reg[ 3] : !quantum.reg -> !quantum.bit + + // CHECK: quantum.custom + %out_qubits:2, %out_ctrl_qubits:2 = quantum.custom "Rot"(%cst, %cst_0, %cst_1) %0, %1 ctrls(%2, %3) ctrlvals(%true, %false) : !quantum.bit, !quantum.bit ctrls !quantum.bit, !quantum.bit + %out_qubits_1:2, %out_ctrl_qubits_1:2 = quantum.custom "Rot"(%cst, %cst_0, %cst_1) %out_qubits#0, %out_qubits#1 {adjoint} ctrls(%out_ctrl_qubits#1, %out_ctrl_qubits#0) ctrlvals(%true, %false) : !quantum.bit, !quantum.bit ctrls !quantum.bit, !quantum.bit + + + return %out_qubits_1#0, %out_qubits_1#1, %out_ctrl_qubits_1#0, %out_ctrl_qubits_1#1 : !quantum.bit, !quantum.bit, !quantum.bit, !quantum.bit +} + + +// ----- + + +// test with unmatched control values + +// CHECK-LABEL: test_chained_self_inverse +func.func @test_chained_self_inverse() -> (!quantum.bit, !quantum.bit, !quantum.bit, !quantum.bit) { + %true = llvm.mlir.constant (1 : i1) :i1 + %false = llvm.mlir.constant (0 : i1) :i1 + %cst = llvm.mlir.constant (6.000000e-01 : f64) : f64 + %cst_0 = llvm.mlir.constant (9.000000e-01 : f64) : f64 + %cst_1 = llvm.mlir.constant (3.000000e-01 : f64) : f64 + + // CHECK: quantum.alloc + // CHECK: [[IN0:%.+]] = quantum.extract {{.+}}[ 0] + // CHECK: [[IN1:%.+]] = quantum.extract {{.+}}[ 1] + // CHECK: [[IN2:%.+]] = quantum.extract {{.+}}[ 2] + // CHECK: [[IN3:%.+]] = quantum.extract {{.+}}[ 3] + %reg = quantum.alloc( 4) : !quantum.reg + %0 = quantum.extract %reg[ 0] : !quantum.reg -> !quantum.bit + %1 = quantum.extract %reg[ 1] : !quantum.reg -> !quantum.bit + %2 = quantum.extract %reg[ 2] : !quantum.reg -> !quantum.bit + %3 = quantum.extract %reg[ 3] : !quantum.reg -> !quantum.bit + + // CHECK: quantum.custom + %out_qubits:2, %out_ctrl_qubits:2 = quantum.custom "Rot"(%cst, %cst_0, %cst_1) %0, %1 ctrls(%2, %3) ctrlvals(%true, %false) : !quantum.bit, !quantum.bit ctrls !quantum.bit, !quantum.bit + %out_qubits_1:2, %out_ctrl_qubits_1:2 = quantum.custom "Rot"(%cst, %cst_0, %cst_1) %out_qubits#0, %out_qubits#1 {adjoint} ctrls(%out_ctrl_qubits#0, %out_ctrl_qubits#1) ctrlvals(%false, %true) : !quantum.bit, !quantum.bit ctrls !quantum.bit, !quantum.bit + + + return %out_qubits_1#0, %out_qubits_1#1, %out_ctrl_qubits_1#0, %out_ctrl_qubits_1#1 : !quantum.bit, !quantum.bit, !quantum.bit, !quantum.bit } From ebb5169695448645dff4498abb11883b7f8fecc4 Mon Sep 17 00:00:00 2001 From: Haochen Wang Date: Wed, 9 Oct 2024 11:19:30 -0400 Subject: [PATCH 24/45] make the named hermitian pattern use the common analysis as well --- .../Transforms/ChainedSelfInversePatterns.cpp | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp b/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp index 06bb16fda7..bc5aeacde1 100644 --- a/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp +++ b/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp @@ -44,6 +44,11 @@ struct ChainedNamedHermitianOpRewritePattern : public mlir::OpRewritePattern vpga(op); + if (!vpga.getVerifierResult()) { + return failure(); + } + StringRef OpGateName = op.getGateName(); if (!HermitianOps.contains(OpGateName)) { return failure(); @@ -51,17 +56,11 @@ struct ChainedNamedHermitianOpRewritePattern : public mlir::OpRewritePattern(InQubits[0].getDefiningOp()); - if (!ParentOp || ParentOp.getGateName() != OpGateName) { + if (ParentOp.getGateName() != OpGateName) { return failure(); } - ValueRange ParentOutQubits = ParentOp.getOutQubits(); - // Check if the input qubits to the current operation match the output qubits of the parent. - for (const auto &[Idx, Qubit] : llvm::enumerate(InQubits)) { - if (Qubit.getDefiningOp() != ParentOp || Qubit != ParentOutQubits[Idx]) { - return failure(); - } - } + // Replace uses ValueRange simplifiedVal = ParentOp.getInQubits(); rewriter.replaceOp(op, simplifiedVal); return success(); From 5c77fb381c0a9fdf4bd7dccf1809e4ccab45c00a Mon Sep 17 00:00:00 2001 From: Haochen Wang Date: Wed, 9 Oct 2024 11:25:57 -0400 Subject: [PATCH 25/45] one more test --- mlir/test/Quantum/ChainedSelfInverseTest.mlir | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/mlir/test/Quantum/ChainedSelfInverseTest.mlir b/mlir/test/Quantum/ChainedSelfInverseTest.mlir index 4b082c0936..a5d03177db 100644 --- a/mlir/test/Quantum/ChainedSelfInverseTest.mlir +++ b/mlir/test/Quantum/ChainedSelfInverseTest.mlir @@ -521,5 +521,38 @@ func.func @test_chained_self_inverse() -> (!quantum.bit, !quantum.bit, !quantum. %out_qubits_1:2, %out_ctrl_qubits_1:2 = quantum.custom "Rot"(%cst, %cst_0, %cst_1) %out_qubits#0, %out_qubits#1 {adjoint} ctrls(%out_ctrl_qubits#0, %out_ctrl_qubits#1) ctrlvals(%false, %true) : !quantum.bit, !quantum.bit ctrls !quantum.bit, !quantum.bit + return %out_qubits_1#0, %out_qubits_1#1, %out_ctrl_qubits_1#0, %out_ctrl_qubits_1#1 : !quantum.bit, !quantum.bit, !quantum.bit, !quantum.bit +} + + +// ----- + + +// test with params in the wrong order + +// CHECK-LABEL: test_chained_self_inverse +func.func @test_chained_self_inverse() -> (!quantum.bit, !quantum.bit, !quantum.bit, !quantum.bit) { + %true = llvm.mlir.constant (1 : i1) :i1 + %false = llvm.mlir.constant (0 : i1) :i1 + %cst = llvm.mlir.constant (6.000000e-01 : f64) : f64 + %cst_0 = llvm.mlir.constant (9.000000e-01 : f64) : f64 + %cst_1 = llvm.mlir.constant (3.000000e-01 : f64) : f64 + + // CHECK: quantum.alloc + // CHECK: [[IN0:%.+]] = quantum.extract {{.+}}[ 0] + // CHECK: [[IN1:%.+]] = quantum.extract {{.+}}[ 1] + // CHECK: [[IN2:%.+]] = quantum.extract {{.+}}[ 2] + // CHECK: [[IN3:%.+]] = quantum.extract {{.+}}[ 3] + %reg = quantum.alloc( 4) : !quantum.reg + %0 = quantum.extract %reg[ 0] : !quantum.reg -> !quantum.bit + %1 = quantum.extract %reg[ 1] : !quantum.reg -> !quantum.bit + %2 = quantum.extract %reg[ 2] : !quantum.reg -> !quantum.bit + %3 = quantum.extract %reg[ 3] : !quantum.reg -> !quantum.bit + + // CHECK: quantum.custom + %out_qubits:2, %out_ctrl_qubits:2 = quantum.custom "Rot"(%cst, %cst_0, %cst_1) %0, %1 ctrls(%2, %3) ctrlvals(%true, %false) : !quantum.bit, !quantum.bit ctrls !quantum.bit, !quantum.bit + %out_qubits_1:2, %out_ctrl_qubits_1:2 = quantum.custom "Rot"(%cst_0, %cst, %cst_1) %out_qubits#0, %out_qubits#1 {adjoint} ctrls(%out_ctrl_qubits#0, %out_ctrl_qubits#1) ctrlvals(%true, %false) : !quantum.bit, !quantum.bit ctrls !quantum.bit, !quantum.bit + + return %out_qubits_1#0, %out_qubits_1#1, %out_ctrl_qubits_1#0, %out_ctrl_qubits_1#1 : !quantum.bit, !quantum.bit, !quantum.bit, !quantum.bit } From 9ae9ba71aace85da3b284a474c03a582655c17ee Mon Sep 17 00:00:00 2001 From: Haochen Wang Date: Wed, 9 Oct 2024 15:28:06 -0400 Subject: [PATCH 26/45] follow include order guideline --- mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp b/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp index bc5aeacde1..75f765be79 100644 --- a/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp +++ b/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp @@ -14,8 +14,6 @@ #define DEBUG_TYPE "chained-self-inverse" -#include "VerifyParentGateAnalysis.hpp" - #include "llvm/ADT/StringSet.h" #include "llvm/Support/Debug.h" #include "llvm/Support/Errc.h" @@ -23,6 +21,8 @@ #include "Quantum/IR/QuantumOps.h" #include "Quantum/Transforms/Patterns.h" +#include "VerifyParentGateAnalysis.hpp" + using llvm::dbgs; using namespace mlir; using namespace catalyst; From 738c96a61c6127b3c54c9568a0f159b8f1548f0b Mon Sep 17 00:00:00 2001 From: Haochen Wang Date: Wed, 9 Oct 2024 15:32:02 -0400 Subject: [PATCH 27/45] `verified` --> `succeeded` --- mlir/lib/Quantum/Transforms/VerifyParentGateAnalysis.hpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Quantum/Transforms/VerifyParentGateAnalysis.hpp b/mlir/lib/Quantum/Transforms/VerifyParentGateAnalysis.hpp index 57009dd060..bf3edb99b4 100644 --- a/mlir/lib/Quantum/Transforms/VerifyParentGateAnalysis.hpp +++ b/mlir/lib/Quantum/Transforms/VerifyParentGateAnalysis.hpp @@ -55,20 +55,20 @@ template class VerifyParentGateAnalysis { auto parentGate = dyn_cast_or_null(inQubits[0].getDefiningOp()); if (!verifyParentGateType(gate, parentGate)) { - verified = false; + succeeded = false; return; } if (!verifyAllInQubits(gate, parentGate)) { - verified = false; + succeeded = false; return; } } - bool getVerifierResult() { return verified; } + bool getVerifierResult() { return succeeded; } private: - bool verified = true; + bool succeeded = true; bool verifyParentGateType(OpType op, OpType parentOp) const { From e78cca114932bcda0cfb900101df9fbb27c8551c Mon Sep 17 00:00:00 2001 From: Haochen Wang Date: Thu, 10 Oct 2024 09:06:54 -0400 Subject: [PATCH 28/45] move namecheck before wire verification --- .../lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp b/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp index 75f765be79..787a77e81a 100644 --- a/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp +++ b/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp @@ -44,13 +44,13 @@ struct ChainedNamedHermitianOpRewritePattern : public mlir::OpRewritePattern vpga(op); - if (!vpga.getVerifierResult()) { + StringRef OpGateName = op.getGateName(); + if (!HermitianOps.contains(OpGateName)) { return failure(); } - StringRef OpGateName = op.getGateName(); - if (!HermitianOps.contains(OpGateName)) { + VerifyParentGateAnalysis vpga(op); + if (!vpga.getVerifierResult()) { return failure(); } From 6324fd087893ccd5b786c3c8c3100639e5688a6d Mon Sep 17 00:00:00 2001 From: Romain Moyard Date: Thu, 10 Oct 2024 09:43:29 -0400 Subject: [PATCH 29/45] Add analysis integration --- .../Transforms/MergeRotationsPatterns.cpp | 52 ++++--------------- 1 file changed, 10 insertions(+), 42 deletions(-) diff --git a/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp b/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp index ecdf776593..38f3a9d4e7 100644 --- a/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp +++ b/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp @@ -16,6 +16,7 @@ #include "Quantum/IR/QuantumOps.h" #include "Quantum/Transforms/Patterns.h" +#include "VerifyParentGateAnalysis.hpp" #include "mlir/Dialect/Arith/IR/Arith.h" #include "llvm/ADT/StringSet.h" #include "llvm/Support/Debug.h" @@ -47,11 +48,9 @@ struct MergeRotationsRewritePattern : public mlir::OpRewritePattern { if (!parentOp || parentOp.getGateName() != opGateName) return failure(); - ValueRange parentOutQubits = parentOp.getOutQubits(); - // Check if the input qubits to the current operation match the output qubits of the parent. - for (const auto &[Idx, Qubit] : llvm::enumerate(inQubits)) { - if (Qubit.getDefiningOp() != parentOp || Qubit != parentOutQubits[Idx]) - return failure(); + VerifyParentGateAnalysis vpga(op); + if (!vpga.getVerifierResult()) { + return failure(); } TypeRange outQubitsTypes = op.getOutQubits().getTypes(); @@ -62,9 +61,9 @@ struct MergeRotationsRewritePattern : public mlir::OpRewritePattern { auto parentParams = parentOp.getParams(); auto params = op.getParams(); - std::vector sumParams; + SmallVector sumParams; for (auto [param, parentParam] : llvm::zip(params, parentParams)) { - Value sumParam = rewriter.create(loc, parentParam, param).getResult(); + mlir::Value sumParam = rewriter.create(loc, parentParam, param).getResult(); sumParams.push_back(sumParam); }; auto mergeOp = rewriter.create(loc, outQubitsTypes, outQubitsCtrlTypes, sumParams, @@ -86,45 +85,14 @@ struct MergeMultiRZRewritePattern : public mlir::OpRewritePattern { mlir::PatternRewriter &rewriter) const override { LLVM_DEBUG(dbgs() << "Simplifying the following operation:\n" << op << "\n"); - auto loc = op.getLoc(); - ValueRange InQubits = op.getInQubits(); + // auto loc = op.getLoc(); + // ValueRange InQubits = op.getInQubits(); - // Check parent op - auto parentOp = dyn_cast_or_null(InQubits[0].getDefiningOp()); - - if (!parentOp) + VerifyParentGateAnalysis vpga(op); + if (!vpga.getVerifierResult()) { return failure(); - - // Check the target qubit - ValueRange parentOutQubits = parentOp.getOutQubits(); - for (const auto &[Idx, Qubit] : llvm::enumerate(InQubits)) { - if (Qubit.getDefiningOp() != parentOp || Qubit != parentOutQubits[Idx]) - return failure(); } - // Check the control qubits - ValueRange inCtrlQubits = op.getInCtrlQubits(); - ValueRange parentOutCtrlQubits = parentOp.getOutCtrlQubits(); - for (const auto &[Idx, Qubit] : llvm::enumerate(InQubits)) { - if (Qubit.getDefiningOp() != parentOp || Qubit != parentOutCtrlQubits[Idx]) - return failure(); - } - // Check the control values - - // ... - - // Sum the angles control values - - // Replace operation - TypeRange outQubitsTypes = op.getOutQubits().getTypes(); - TypeRange outQubitsCtrlTypes = op.getOutCtrlQubits().getTypes(); - ValueRange parentInQubits = parentOp.getInQubits(); - ValueRange parentInCtrlQubits = parentOp.getInCtrlQubits(); - ValueRange parentInCtrlValues = parentOp.getInCtrlValues(); - // op.replaceAllUsesWith(mergeOp); - // op.erase(); - // parentOp.erase(); - return success(); } }; From a6b8424b9d208ea0636de6466a79ee3ceabd8fa7 Mon Sep 17 00:00:00 2001 From: Romain Moyard Date: Thu, 10 Oct 2024 09:50:42 -0400 Subject: [PATCH 30/45] MultiRz case --- .../Transforms/MergeRotationsPatterns.cpp | 29 +++++++++++++++++-- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp b/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp index 38f3a9d4e7..802d99f57e 100644 --- a/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp +++ b/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp @@ -63,7 +63,8 @@ struct MergeRotationsRewritePattern : public mlir::OpRewritePattern { auto params = op.getParams(); SmallVector sumParams; for (auto [param, parentParam] : llvm::zip(params, parentParams)) { - mlir::Value sumParam = rewriter.create(loc, parentParam, param).getResult(); + mlir::Value sumParam = + rewriter.create(loc, parentParam, param).getResult(); sumParams.push_back(sumParam); }; auto mergeOp = rewriter.create(loc, outQubitsTypes, outQubitsCtrlTypes, sumParams, @@ -85,14 +86,36 @@ struct MergeMultiRZRewritePattern : public mlir::OpRewritePattern { mlir::PatternRewriter &rewriter) const override { LLVM_DEBUG(dbgs() << "Simplifying the following operation:\n" << op << "\n"); - // auto loc = op.getLoc(); - // ValueRange InQubits = op.getInQubits(); + auto loc = op.getLoc(); VerifyParentGateAnalysis vpga(op); if (!vpga.getVerifierResult()) { return failure(); } + ValueRange inQubits = op.getInQubits(); + auto parentOp = dyn_cast_or_null(inQubits[0].getDefiningOp()); + if (!parentOp) + return failure(); + + TypeRange outQubitsTypes = op.getOutQubits().getTypes(); + TypeRange outQubitsCtrlTypes = op.getOutCtrlQubits().getTypes(); + ValueRange parentInQubits = parentOp.getInQubits(); + ValueRange parentInCtrlQubits = parentOp.getInCtrlQubits(); + ValueRange parentInCtrlValues = parentOp.getInCtrlValues(); + + auto parentTheta = parentOp.getTheta(); + auto theta = op.getTheta(); + + mlir::Value sumParam = rewriter.create(loc, parentTheta, theta).getResult(); + + auto mergeOp = rewriter.create(loc, outQubitsTypes, outQubitsCtrlTypes, + sumParam, parentInQubits, nullptr, + parentInCtrlQubits, parentInCtrlValues); + op.replaceAllUsesWith(mergeOp); + op.erase(); + parentOp.erase(); + return success(); } }; From b24703a06588aac0149d4f50c2aa6b015e40368b Mon Sep 17 00:00:00 2001 From: Haochen Wang Date: Thu, 10 Oct 2024 13:59:32 -0400 Subject: [PATCH 31/45] Split verifier into a "normal" one and an aggressive one. The aggresive verifier checks the two gates have the same name. --- .../Transforms/ChainedSelfInversePatterns.cpp | 8 +-- .../Transforms/VerifyParentGateAnalysis.hpp | 53 +++++++++++++++---- 2 files changed, 46 insertions(+), 15 deletions(-) diff --git a/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp b/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp index 787a77e81a..ababbb7858 100644 --- a/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp +++ b/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp @@ -49,8 +49,8 @@ struct ChainedNamedHermitianOpRewritePattern : public mlir::OpRewritePattern vpga(op); - if (!vpga.getVerifierResult()) { + AggressiveVerifyParentGateAnalysis avpga(op); + if (!avpga.getVerifierResult()) { return failure(); } @@ -108,8 +108,8 @@ struct ChainedUUadjOpRewritePattern : public mlir::OpRewritePattern { { LLVM_DEBUG(dbgs() << "Simplifying the following operation:\n" << op << "\n"); - VerifyParentGateAnalysis vpga(op); - if (!vpga.getVerifierResult()) { + AggressiveVerifyParentGateAnalysis avpga(op); + if (!avpga.getVerifierResult()) { return failure(); } diff --git a/mlir/lib/Quantum/Transforms/VerifyParentGateAnalysis.hpp b/mlir/lib/Quantum/Transforms/VerifyParentGateAnalysis.hpp index bf3edb99b4..8f6889728e 100644 --- a/mlir/lib/Quantum/Transforms/VerifyParentGateAnalysis.hpp +++ b/mlir/lib/Quantum/Transforms/VerifyParentGateAnalysis.hpp @@ -18,7 +18,7 @@ // Gates passing this analysis are considered valid candidates for merge // rotation and cancel inverses. -// Specifically, we check the following conditions: +// Specifically, we check the following conditions in VerifyParentGateAnalysis: // 1. Both gates must be of the same type, i.e. a quantum.custom can // only be cancelled with a quantum.custom, not a quantum.unitary // 2. The results of the parent gate must map one-to-one, in order, @@ -32,6 +32,10 @@ // is not. // 3. If the gates are controlled, both gates' control wires and values // must be the same. The control wires must be in the same order +// +// On top of the above, we also provide a AggressiveVerifyParentGateAnalysis, +// which also checks: +// 4. If the gates are quantum.custom, then both gates have the same name. #pragma once @@ -55,18 +59,20 @@ template class VerifyParentGateAnalysis { auto parentGate = dyn_cast_or_null(inQubits[0].getDefiningOp()); if (!verifyParentGateType(gate, parentGate)) { - succeeded = false; + setVerifierResult(false); return; } if (!verifyAllInQubits(gate, parentGate)) { - succeeded = false; + setVerifierResult(false); return; } } bool getVerifierResult() { return succeeded; } + void setVerifierResult(bool b) { succeeded = b; } + private: bool succeeded = true; @@ -80,14 +86,6 @@ template class VerifyParentGateAnalysis { return false; } - if (isa(op)) { - StringRef opGateName = cast(op).getGateName(); - StringRef parentGateName = cast(parentOp).getGateName(); - if (opGateName != parentGateName) { - return false; - } - } - return true; } @@ -136,4 +134,37 @@ template class VerifyParentGateAnalysis { } }; +template +class AggressiveVerifyParentGateAnalysis : public VerifyParentGateAnalysis { + public: + AggressiveVerifyParentGateAnalysis(OpType gate) : VerifyParentGateAnalysis(gate) + { + if (!isa(gate)) { + // No extra checks for non quantum.custom ops + return; + } + + ValueRange inQubits = gate.getInQubits(); + auto parentGate = dyn_cast_or_null(inQubits[0].getDefiningOp()); + + if (!parentGate) { + this->setVerifierResult(false); + return; + } + + if (!verifyParentGateName(gate, parentGate)) { + this->setVerifierResult(false); + return; + } + } + + bool verifyParentGateName(OpType op, OpType parentOp) const + { + // If OpType is quantum.custom, also verify that parent gate has the + // same gate name. + StringRef opGateName = cast(op).getGateName(); + StringRef parentGateName = cast(parentOp).getGateName(); + return opGateName == parentGateName; + } +}; } // namespace catalyst From 2f95838a48a5def6e304900a620bc5c7dd133ceb Mon Sep 17 00:00:00 2001 From: Haochen Wang Date: Thu, 10 Oct 2024 14:09:21 -0400 Subject: [PATCH 32/45] use aggressive for named gates --- mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp b/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp index ababbb7858..a981441d4f 100644 --- a/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp +++ b/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp @@ -49,18 +49,15 @@ struct ChainedNamedHermitianOpRewritePattern : public mlir::OpRewritePattern avpga(op); if (!avpga.getVerifierResult()) { return failure(); } + // Replace uses ValueRange InQubits = op.getInQubits(); auto ParentOp = dyn_cast_or_null(InQubits[0].getDefiningOp()); - if (ParentOp.getGateName() != OpGateName) { - return failure(); - } - - // Replace uses ValueRange simplifiedVal = ParentOp.getInQubits(); rewriter.replaceOp(op, simplifiedVal); return success(); From 9bc658bc35e53b9c4dec2d50c2fa2ff4c1e7ac85 Mon Sep 17 00:00:00 2001 From: Haochen Wang Date: Thu, 10 Oct 2024 17:07:25 -0400 Subject: [PATCH 33/45] add multirz --- .../Transforms/ChainedSelfInversePatterns.cpp | 7 +++ .../Transforms/VerifyParentGateAnalysis.hpp | 1 + mlir/test/Quantum/ChainedSelfInverseTest.mlir | 48 +++++++++++++++++++ 3 files changed, 56 insertions(+) diff --git a/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp b/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp index a981441d4f..b8d6229b74 100644 --- a/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp +++ b/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp @@ -142,8 +142,15 @@ namespace quantum { void populateSelfInversePatterns(RewritePatternSet &patterns) { patterns.add(patterns.getContext(), 1); + + // TODO: better organize the quantum dialect + // There is an interface `QuantumGate` for all the unitary gate operations, + // but interfaces cannot be accepted by pattern matchers, since pattern + // matchers require the target operations to have concrete names in the IR. patterns.add>(patterns.getContext(), 1); patterns.add>(patterns.getContext(), 1); + patterns.add>(patterns.getContext(), 1); + //patterns.add>(patterns.getContext(), 1); } } // namespace quantum diff --git a/mlir/lib/Quantum/Transforms/VerifyParentGateAnalysis.hpp b/mlir/lib/Quantum/Transforms/VerifyParentGateAnalysis.hpp index 8f6889728e..9bd22b482d 100644 --- a/mlir/lib/Quantum/Transforms/VerifyParentGateAnalysis.hpp +++ b/mlir/lib/Quantum/Transforms/VerifyParentGateAnalysis.hpp @@ -158,6 +158,7 @@ class AggressiveVerifyParentGateAnalysis : public VerifyParentGateAnalysis (!quantum.bit, !quantum.bit, !quantum. return %out_qubits_1#0, %out_qubits_1#1, %out_ctrl_qubits_1#0, %out_ctrl_qubits_1#1 : !quantum.bit, !quantum.bit, !quantum.bit, !quantum.bit } + + +// ----- + + +// test quantum.multirz labeled with adjoint attribute + +// CHECK-LABEL: test_chained_self_inverse +func.func @test_chained_self_inverse(%arg0: f64) -> (!quantum.bit, !quantum.bit, !quantum.bit) { + // CHECK: quantum.alloc + // CHECK: [[IN0:%.+]] = quantum.extract + // CHECK: [[IN1:%.+]] = quantum.extract + // CHECK: [[IN2:%.+]] = quantum.extract + %0 = quantum.alloc( 3) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + %2 = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit + %3 = quantum.extract %0[ 2] : !quantum.reg -> !quantum.bit + + %mrz:3 = quantum.multirz(%arg0) %1, %2, %3 : !quantum.bit, !quantum.bit, !quantum.bit + %mrz_out:3 = quantum.multirz(%arg0) %mrz#0, %mrz#1, %mrz#2 {adjoint} : !quantum.bit, !quantum.bit, !quantum.bit + + // CHECK-NOT: quantum.multirz + // CHECK: return [[IN0]], [[IN1]], [[IN2]] + return %mrz_out#0, %mrz_out#1, %mrz_out#2 : !quantum.bit, !quantum.bit, !quantum.bit +} + +// ----- + + +// test quantum.multirz but wrong wire order + +// CHECK-LABEL: test_chained_self_inverse +func.func @test_chained_self_inverse(%arg0: f64) -> (!quantum.bit, !quantum.bit, !quantum.bit) { + // CHECK: quantum.alloc + // CHECK: [[IN0:%.+]] = quantum.extract + // CHECK: [[IN1:%.+]] = quantum.extract + // CHECK: [[IN2:%.+]] = quantum.extract + %0 = quantum.alloc( 3) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + %2 = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit + %3 = quantum.extract %0[ 2] : !quantum.reg -> !quantum.bit + + %mrz:3 = quantum.multirz(%arg0) %1, %2, %3 : !quantum.bit, !quantum.bit, !quantum.bit + %mrz_out:3 = quantum.multirz(%arg0) %mrz#1, %mrz#2, %mrz#0 {adjoint} : !quantum.bit, !quantum.bit, !quantum.bit + + // CHECK: quantum.multirz + return %mrz_out#0, %mrz_out#1, %mrz_out#2 : !quantum.bit, !quantum.bit, !quantum.bit +} From 0660b682dd9e355500db5d505417feba291f9e79 Mon Sep 17 00:00:00 2001 From: Haochen Wang Date: Fri, 11 Oct 2024 11:10:45 -0400 Subject: [PATCH 34/45] changelog --- doc/releases/changelog-dev.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index dc6d5f876c..d144a96d22 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -209,7 +209,7 @@ * Samples on lightning.qubit/kokkos can now be seeded with `qjit(seed=...)`. [(#1164)](https://github.com/PennyLaneAI/catalyst/pull/1164) -* The compiler pass `-remove-chained-self-inverse` can now also cancel adjoints of arbitrary unitaries (on top of just the named Hermitian gates). +* The compiler pass `-remove-chained-self-inverse` can now also cancel adjoints of arbitrary unitary operations (in addition to the the named Hermitian gates). [(#1186)](https://github.com/PennyLaneAI/catalyst/pull/1186)

Breaking changes

From a7cb5afa91169fc2d8ef72650a55f8b31dbedac7 Mon Sep 17 00:00:00 2001 From: Haochen Wang Date: Fri, 11 Oct 2024 11:11:50 -0400 Subject: [PATCH 35/45] change aggressive name to VerifyParentGateAndNameAnalysis --- .../Quantum/Transforms/ChainedSelfInversePatterns.cpp | 10 ++++------ .../Quantum/Transforms/VerifyParentGateAnalysis.hpp | 6 +++--- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp b/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp index b8d6229b74..7ebcdacbaa 100644 --- a/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp +++ b/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp @@ -49,9 +49,8 @@ struct ChainedNamedHermitianOpRewritePattern : public mlir::OpRewritePattern avpga(op); - if (!avpga.getVerifierResult()) { + VerifyParentGateAndNameAnalysis vpga(op); + if (!vpga.getVerifierResult()) { return failure(); } @@ -105,8 +104,8 @@ struct ChainedUUadjOpRewritePattern : public mlir::OpRewritePattern { { LLVM_DEBUG(dbgs() << "Simplifying the following operation:\n" << op << "\n"); - AggressiveVerifyParentGateAnalysis avpga(op); - if (!avpga.getVerifierResult()) { + VerifyParentGateAndNameAnalysis vpga(op); + if (!vpga.getVerifierResult()) { return failure(); } @@ -150,7 +149,6 @@ void populateSelfInversePatterns(RewritePatternSet &patterns) patterns.add>(patterns.getContext(), 1); patterns.add>(patterns.getContext(), 1); patterns.add>(patterns.getContext(), 1); - //patterns.add>(patterns.getContext(), 1); } } // namespace quantum diff --git a/mlir/lib/Quantum/Transforms/VerifyParentGateAnalysis.hpp b/mlir/lib/Quantum/Transforms/VerifyParentGateAnalysis.hpp index 9bd22b482d..fbc4a0adad 100644 --- a/mlir/lib/Quantum/Transforms/VerifyParentGateAnalysis.hpp +++ b/mlir/lib/Quantum/Transforms/VerifyParentGateAnalysis.hpp @@ -33,7 +33,7 @@ // 3. If the gates are controlled, both gates' control wires and values // must be the same. The control wires must be in the same order // -// On top of the above, we also provide a AggressiveVerifyParentGateAnalysis, +// On top of the above, we also provide a VerifyParentGateAndNameAnalysis, // which also checks: // 4. If the gates are quantum.custom, then both gates have the same name. @@ -135,9 +135,9 @@ template class VerifyParentGateAnalysis { }; template -class AggressiveVerifyParentGateAnalysis : public VerifyParentGateAnalysis { +class VerifyParentGateAndNameAnalysis : public VerifyParentGateAnalysis { public: - AggressiveVerifyParentGateAnalysis(OpType gate) : VerifyParentGateAnalysis(gate) + VerifyParentGateAndNameAnalysis(OpType gate) : VerifyParentGateAnalysis(gate) { if (!isa(gate)) { // No extra checks for non quantum.custom ops From 02dc927ff128c4b2c865a388d7a3702b25a3af9d Mon Sep 17 00:00:00 2001 From: Haochen Wang Date: Fri, 11 Oct 2024 11:26:56 -0400 Subject: [PATCH 36/45] changelog grammar --- doc/releases/changelog-dev.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index d144a96d22..614c2ddfc4 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -209,7 +209,7 @@ * Samples on lightning.qubit/kokkos can now be seeded with `qjit(seed=...)`. [(#1164)](https://github.com/PennyLaneAI/catalyst/pull/1164) -* The compiler pass `-remove-chained-self-inverse` can now also cancel adjoints of arbitrary unitary operations (in addition to the the named Hermitian gates). +* The compiler pass `-remove-chained-self-inverse` can now also cancel adjoints of arbitrary unitary operations (in addition to the named Hermitian gates). [(#1186)](https://github.com/PennyLaneAI/catalyst/pull/1186)

Breaking changes

From 7ba892f49cbbcc600542b827fa8eb8c9f3fa6bef Mon Sep 17 00:00:00 2001 From: Romain Moyard Date: Fri, 11 Oct 2024 11:42:53 -0400 Subject: [PATCH 37/45] Add multirz test --- .../Transforms/MergeRotationsPatterns.cpp | 5 +- mlir/test/Quantum/MergeRotationsTest.mlir | 60 +++++++++++++++++++ 2 files changed, 61 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp b/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp index 802d99f57e..7cd45cce99 100644 --- a/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp +++ b/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp @@ -45,10 +45,7 @@ struct MergeRotationsRewritePattern : public mlir::OpRewritePattern { ValueRange inQubits = op.getInQubits(); auto parentOp = dyn_cast_or_null(inQubits[0].getDefiningOp()); - if (!parentOp || parentOp.getGateName() != opGateName) - return failure(); - - VerifyParentGateAnalysis vpga(op); + VerifyParentGateAndNameAnalysis vpga(op); if (!vpga.getVerifierResult()) { return failure(); } diff --git a/mlir/test/Quantum/MergeRotationsTest.mlir b/mlir/test/Quantum/MergeRotationsTest.mlir index ca5d1b1344..c1d55f67be 100644 --- a/mlir/test/Quantum/MergeRotationsTest.mlir +++ b/mlir/test/Quantum/MergeRotationsTest.mlir @@ -193,3 +193,63 @@ func.func @test_merge_rotations(%arg0: f64, %arg1: f64, %arg2: f64) -> (!quantum // CHECK: return [[ret]]#0, [[ret]]#1 return %5#0, %5#1 : !quantum.bit, !quantum.bit } + + +// ----- + +func.func @test_merge_rotations(%arg0: f64, %arg1: f64, %arg2: f64) -> (!quantum.bit, !quantum.bit) { + // CHECK: [[reg:%.+]] = quantum.alloc( 2) : !quantum.reg + // CHECK: [[qubit1:%.+]] = quantum.extract [[reg]][ 0] : !quantum.reg -> !quantum.bit + // CHECK: [[qubit2:%.+]] = quantum.extract [[reg]][ 1] : !quantum.reg -> !quantum.bit + %0 = quantum.alloc( 2) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + %2 = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit + + // CHECK: [[theta1:%.+]] = arith.addf %arg1, %arg2 : f64 + // CHECK: [[theta2:%.+]] = arith.addf %arg0, [[theta1]] : f64 + // CHECK: [[ret:%.+]]:2 = quantum.multirz([[theta2]]) [[qubit1]], [[qubit2]] : !quantum.bit, !quantum.bit + // CHECK-NOT: quantum.multirz + %3:2 = quantum.multirz (%arg0) %1, %2 : !quantum.bit, !quantum.bit + %4:2 = quantum.multirz (%arg1) %3#0, %3#1 : !quantum.bit, !quantum.bit + %5:2 = quantum.multirz (%arg2) %4#0, %4#1 : !quantum.bit, !quantum.bit + // CHECK: return [[ret]]#0, [[ret]]#1 + return %5#0, %5#1 : !quantum.bit, !quantum.bit +} + +// ----- + +func.func @test_merge_rotations(%arg0: f64, %arg1: f64, %arg2: f64) -> (!quantum.bit, !quantum.bit) { + // CHECK: [[reg:%.+]] = quantum.alloc( 2) : !quantum.reg + // CHECK: [[qubit1:%.+]] = quantum.extract [[reg]][ 0] : !quantum.reg -> !quantum.bit + // CHECK: [[qubit2:%.+]] = quantum.extract [[reg]][ 1] : !quantum.reg -> !quantum.bit + %0 = quantum.alloc( 2) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + %2 = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit + %3 = quantum.extract %0[ 2] : !quantum.reg -> !quantum.bit + // CHECK-NOT: arith.addf + // CHECK: quantum.multirz + // CHECK: quantum.multirz + // CHECK-NOT: quantum.multirz + %4:2 = quantum.multirz (%arg0) %1, %2 : !quantum.bit, !quantum.bit + %5:2 = quantum.multirz (%arg1) %4#0, %3 : !quantum.bit, !quantum.bit + return %5#0, %5#1 : !quantum.bit, !quantum.bit +} + +// ----- + +func.func @test_merge_rotations(%arg0: f64, %arg1: f64, %arg2: f64) -> (!quantum.bit, !quantum.bit) { + // CHECK: [[reg:%.+]] = quantum.alloc( 2) : !quantum.reg + // CHECK: [[qubit1:%.+]] = quantum.extract [[reg]][ 0] : !quantum.reg -> !quantum.bit + // CHECK: [[qubit2:%.+]] = quantum.extract [[reg]][ 1] : !quantum.reg -> !quantum.bit + %0 = quantum.alloc( 2) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + %2 = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit + %3 = quantum.extract %0[ 2] : !quantum.reg -> !quantum.bit + // CHECK-NOT: arith.addf + // CHECK: quantum.multirz + // CHECK: quantum.multirz + // CHECK-NOT: quantum.multirz + %4:2 = quantum.multirz (%arg0) %1, %2 : !quantum.bit, !quantum.bit + %5 = quantum.multirz (%arg1) %4#0 : !quantum.bit + return %5, %4#1 : !quantum.bit, !quantum.bit +} \ No newline at end of file From 4ae4281a0ddf97dc1dadb0757a2e4a2e2ed1171b Mon Sep 17 00:00:00 2001 From: Romain Moyard Date: Fri, 11 Oct 2024 13:02:25 -0400 Subject: [PATCH 38/45] Update doc --- frontend/catalyst/passes.py | 83 +++++++++++++++- .../test/lit/test_peephole_optimizations.py | 97 ++++++++++++++++--- .../pytest/test_peephole_optimizations.py | 54 +++++++++-- 3 files changed, 209 insertions(+), 25 deletions(-) diff --git a/frontend/catalyst/passes.py b/frontend/catalyst/passes.py index 70f4c86e2c..1716137cfa 100644 --- a/frontend/catalyst/passes.py +++ b/frontend/catalyst/passes.py @@ -310,6 +310,87 @@ def wrapper(*args, **kwrags): return fn_clone +def merge_rotations(fn=None): + """ + Specify that the ``-merge-rotations`` MLIR compiler pass + for merging roations (peephole) will be applied. + + The full list of supported gates are as follows: + + :class:`qml.RX `, + :class:`qml.CRX `, + :class:`qml.RY `, + :class:`qml.CRY `, + :class:`qml.RZ `, + :class:`qml.CRZ `, + :class:`qml.PhaseShift `, + :class:`qml.ControlledPhaseShift `, + :class:`qml.Rot `, + :class:`qml.CRot `, + :class:`qml.MultiRZ `. + + + .. note:: + + Unlike PennyLane :doc:`circuit transformations `, + the QNode itself will not be changed or transformed by applying these + decorators. + + As a result, circuit inspection tools such as :func:`~.draw` will continue + to display the circuit as written in Python. + + Args: + fn (QNode): the QNode to apply the cancel inverses compiler pass to + + Returns: + ~.QNode: + + **Example** + + In this example the three :class:`qml.RX ` will be merged in a single + one with the sum of angles as parameter. + + .. code-block:: python + + from catalyst.debug import get_compilation_stage + from catalyst.passes import merge_rotations + + dev = qml.device("lightning.qubit", wires=1) + + @qjit(keep_intermediate=True) + @merge_rotations + @qml.qnode(dev) + def circuit(x: float): + qml.RX(x, wires=0) + qml.RX(0.1, wires=0) + qml.RX(x**2, wires=0) + return qml.expval(qml.PauliZ(0)) + + >>> circuit(0.54) + Array(0.5965506257017892, dtype=float64) + """ + if not isinstance(fn, qml.QNode): + raise TypeError(f"A QNode is expected, got the classical function {fn}") + + funcname = fn.__name__ + wrapped_qnode_function = fn.func + uniquer = str(_rename_to_unique()) + + def wrapper(*args, **kwrags): + if EvaluationContext.is_tracing(): + apply_registered_pass_p.bind( + pass_name="merge-rotations", + options=f"func-name={funcname}" + "_merge_rotations" + uniquer, + ) + return wrapped_qnode_function(*args, **kwrags) + + fn_clone = copy.copy(fn) + fn_clone.func = wrapper + fn_clone.__name__ = funcname + "_merge_rotations" + uniquer + + return fn_clone + + ## IMPL and helpers ## # pylint: disable=missing-function-docstring class _PipelineNameUniquer: @@ -332,7 +413,7 @@ def _rename_to_unique(): def _API_name_to_pass_name(): - return {"cancel_inverses": "remove-chained-self-inverse", "merge_rotations": "merge-rotation"} + return {"cancel_inverses": "remove-chained-self-inverse", "merge_rotations": "merge-rotations"} def _inject_transform_named_sequence(): diff --git a/frontend/test/lit/test_peephole_optimizations.py b/frontend/test/lit/test_peephole_optimizations.py index f5b67b67a8..27596cf69e 100644 --- a/frontend/test/lit/test_peephole_optimizations.py +++ b/frontend/test/lit/test_peephole_optimizations.py @@ -31,7 +31,7 @@ from catalyst import qjit from catalyst.debug import get_compilation_stage -from catalyst.passes import cancel_inverses, pipeline +from catalyst.passes import cancel_inverses, merge_rotations, pipeline def flush_peephole_opted_mlir_to_iostream(QJIT): @@ -86,7 +86,7 @@ def test_pipeline_lowering(): """ my_pipeline = { "cancel_inverses": {}, - "merge_rotations": {"my-option": "aloha"}, + "merge_rotations": {}, } @qjit(keep_intermediate=True) @@ -104,14 +104,14 @@ def test_pipeline_lowering_workflow(x): # CHECK: pass_name=remove-chained-self-inverse # CHECK: ] # CHECK: _:AbstractTransformMod() = apply_registered_pass[ - # CHECK: options=func-name=test_pipeline_lowering_workflow_transformed0 my-option=aloha - # CHECK: pass_name=merge-rotation + # CHECK: options=func-name=test_pipeline_lowering_workflow_transformed0 + # CHECK: pass_name=merge-rotations # CHECK: ] print_jaxpr(test_pipeline_lowering_workflow, 1.2) # CHECK: transform.named_sequence @__transform_main # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=test_pipeline_lowering_workflow_transformed0"} - # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotation" to {{%.+}} {options = "func-name=test_pipeline_lowering_workflow_transformed0 my-option=aloha"} + # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotations" to {{%.+}} {options = "func-name=test_pipeline_lowering_workflow_transformed0"} # CHECK-NEXT: transform.yield print_mlir(test_pipeline_lowering_workflow, 1.2) @@ -160,13 +160,13 @@ def test_pipeline_lowering_keep_original_workflow(x): # CHECK: ] # CHECK: _:AbstractTransformMod() = apply_registered_pass[ # CHECK: options=func-name=f_transformed0 - # CHECK: pass_name=merge-rotation + # CHECK: pass_name=merge-rotations # CHECK: ] print_jaxpr(test_pipeline_lowering_keep_original_workflow, 1.2) # CHECK: transform.named_sequence @__transform_main # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=f_transformed0"} - # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotation" to {{%.+}} {options = "func-name=f_transformed0"} + # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotations" to {{%.+}} {options = "func-name=f_transformed0"} # CHECK-NEXT: transform.yield print_mlir(test_pipeline_lowering_keep_original_workflow, 1.2) @@ -223,7 +223,7 @@ def h(x): # CHECK: ] # CHECK: _:AbstractTransformMod() = apply_registered_pass[ # CHECK: options=func-name=g_transformed0 - # CHECK: pass_name=merge-rotation + # CHECK: pass_name=merge-rotations # CHECK: ] # CHECK: _:AbstractTransformMod() = apply_registered_pass[ # CHECK: options=func-name=h_transformed1 @@ -231,15 +231,15 @@ def h(x): # CHECK: ] # CHECK: _:AbstractTransformMod() = apply_registered_pass[ # CHECK: options=func-name=h_transformed1 - # CHECK: pass_name=merge-rotation + # CHECK: pass_name=merge-rotations # CHECK: ] print_jaxpr(global_wf) # CHECK: transform.named_sequence @__transform_main # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=g_transformed0"} - # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotation" to {{%.+}} {options = "func-name=g_transformed0"} + # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotations" to {{%.+}} {options = "func-name=g_transformed0"} # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=h_transformed1"} - # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotation" to {{%.+}} {options = "func-name=h_transformed1"} + # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotations" to {{%.+}} {options = "func-name=h_transformed1"} # CHECK-NEXT: transform.yield print_mlir(global_wf) @@ -301,20 +301,20 @@ def h(x): # CHECK: ] # CHECK: _:AbstractTransformMod() = apply_registered_pass[ # CHECK: options=func-name=g_transformed1 - # CHECK: pass_name=merge-rotation + # CHECK: pass_name=merge-rotations # CHECK: ] # CHECK: _:AbstractTransformMod() = apply_registered_pass[ # CHECK: options=func-name=h_transformed0 # CHECK-NOT: pass_name=remove-chained-self-inverse - # CHECK: pass_name=merge-rotation + # CHECK: pass_name=merge-rotations # CHECK: ] print_jaxpr(global_wf) # CHECK: transform.named_sequence @__transform_main # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=g_transformed1"} - # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotation" to {{%.+}} {options = "func-name=g_transformed1"} + # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotations" to {{%.+}} {options = "func-name=g_transformed1"} # CHECK-NOT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=h_transformed0"} - # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotation" to {{%.+}} {options = "func-name=h_transformed0"} + # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotations" to {{%.+}} {options = "func-name=h_transformed0"} # CHECK-NEXT: transform.yield print_mlir(global_wf) @@ -563,3 +563,70 @@ def test_cancel_inverses_keep_original_workflow2(): test_cancel_inverses_keep_original() + + +# +# merge_rotations +# + + +def test_merge_rotations_tracing_and_lowering(): + """ + Test merge_rotations during tracing and lowering + """ + + @qjit + def test_merge_rotations_tracing_and_lowering_workflow(xx: float): + + @merge_rotations + @qml.qnode(qml.device("lightning.qubit", wires=1)) + def f(x: float): + qml.RX(x, wires=0) + qml.RX(x, wires=0) + qml.Hadamard(wires=0) + return qml.expval(qml.PauliZ(0)) + + @merge_rotations + @qml.qnode(qml.device("lightning.qubit", wires=1)) + def g(x: float): + qml.RX(x, wires=0) + qml.RX(x, wires=0) + qml.Hadamard(wires=0) + return qml.expval(qml.PauliZ(0)) + + @qml.qnode(qml.device("lightning.qubit", wires=1)) + def h(x: float): + qml.RX(x, wires=0) + qml.RX(x, wires=0) + qml.Hadamard(wires=0) + return qml.expval(qml.PauliZ(0)) + + _f = f(xx) + _g = g(xx) + _h = h(xx) + return _f, _g, _h + + # CHECK: transform_named_sequence + # CHECK: _:AbstractTransformMod() = apply_registered_pass[ + # CHECK: options=func-name=f_merge_rotations0 + # CHECK: pass_name=merge-rotations + # CHECK: ] + # CHECK: _:AbstractTransformMod() = apply_registered_pass[ + # CHECK: options=func-name=g_merge_rotations1 + # CHECK: pass_name=merge-rotations + # CHECK: ] + # CHECK-NOT: _:AbstractTransformMod() = apply_registered_pass[ + # CHECK-NOT: options=func-name=h_merge_rotations + # CHECK-NOT: pass_name=merge-rotations + print_jaxpr(test_merge_rotations_tracing_and_lowering_workflow, 1.1) + + # CHECK: module @test_merge_rotations_tracing_and_lowering_workflow + # CHECK: transform.named_sequence @__transform_main + # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotations" to {{%.+}} {options = "func-name=f_merge_rotations0"} + # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotations" to {{%.+}} {options = "func-name=g_merge_rotations1"} + # CHECK-NOT: {{%.+}} = transform.apply_registered_pass "merge-rotations" to {{%.+}} {options = "func-name=h_merge_rotations"} + # CHECK-NEXT: transform.yield + print_mlir(test_merge_rotations_tracing_and_lowering_workflow, 1.1) + + +test_merge_rotations_tracing_and_lowering() diff --git a/frontend/test/pytest/test_peephole_optimizations.py b/frontend/test/pytest/test_peephole_optimizations.py index 40bbf0fb28..6b552edfc0 100644 --- a/frontend/test/pytest/test_peephole_optimizations.py +++ b/frontend/test/pytest/test_peephole_optimizations.py @@ -19,7 +19,7 @@ import pytest from catalyst import qjit -from catalyst.passes import cancel_inverses, pipeline +from catalyst.passes import cancel_inverses, merge_rotations, pipeline # pylint: disable=missing-function-docstring @@ -63,6 +63,42 @@ def reference(x): assert np.allclose(workflow()[1], reference(theta)) +@pytest.mark.parametrize("theta", [42.42]) +def test_cancel_inverses_functionality(theta, backend): + + @qjit + def workflow(): + @qml.qnode(qml.device(backend, wires=1)) + def f(x): + qml.RX(x, wires=0) + qml.RX(x, wires=0) + qml.Hadamard(wires=0) + qml.Hadamard(wires=0) + return qml.probs() + + @merge_rotations + @qml.qnode(qml.device(backend, wires=1)) + def g(x): + qml.RX(x, wires=0) + qml.RX(x, wires=0) + qml.Hadamard(wires=0) + qml.Hadamard(wires=0) + return qml.probs() + + return f(theta), g(theta) + + @qml.qnode(qml.device("default.qubit", wires=1)) + def reference(x): + qml.RX(x, wires=0) + qml.RX(x, wires=0) + qml.Hadamard(wires=0) + qml.Hadamard(wires=0) + return qml.probs() + + assert np.allclose(workflow()[0], workflow()[1]) + assert np.allclose(workflow()[1], reference(theta)) + + @pytest.mark.parametrize("theta", [42.42]) def test_cancel_inverses_functionality_outside_qjit(theta, backend): @@ -99,13 +135,14 @@ def test_pipeline_functionality(capfd, theta, backend): """ my_pipeline = { "cancel_inverses": {}, - "merge_rotations": {"my-option": "aloha"}, + "merge_rotations": {}, } @qjit def workflow(): @qml.qnode(qml.device(backend, wires=2)) def f(x): + qml.RX(0.1, wires=[0]) qml.RX(x, wires=[0]) qml.Hadamard(wires=[1]) qml.Hadamard(wires=[1]) @@ -119,13 +156,6 @@ def f(x): res = workflow() assert np.allclose(res[0], res[1]) - # TODO: the boilerplate merge rotation pass prints out different messages based on - # the pass option. - # The purpose is to test the integration of pass options with pipeline decorator. - # Remove the string check when merge rotation becomes the actual merge rotation pass. - output_message = capfd.readouterr().err - assert output_message == "merge rotation pass, aloha!\n" - ### Test bad usages of pass decorators ### def test_cancel_inverses_bad_usages(): @@ -149,6 +179,12 @@ def classical_func(): ): cancel_inverses(classical_func) + with pytest.raises( + TypeError, + match="A QNode is expected, got the classical function", + ): + merge_rotations(classical_func) + test_cancel_inverses_not_on_qnode() From 1b970dad55e2cde2da611ed351808d3c6cff3ae6 Mon Sep 17 00:00:00 2001 From: Romain Moyard Date: Fri, 11 Oct 2024 13:55:02 -0400 Subject: [PATCH 39/45] Update mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp Co-authored-by: paul0403 <79805239+paul0403@users.noreply.github.com> --- mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp b/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp index 7cd45cce99..9fab424ae8 100644 --- a/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp +++ b/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp @@ -45,7 +45,7 @@ struct MergeRotationsRewritePattern : public mlir::OpRewritePattern { ValueRange inQubits = op.getInQubits(); auto parentOp = dyn_cast_or_null(inQubits[0].getDefiningOp()); - VerifyParentGateAndNameAnalysis vpga(op); + VerifyParentGateAndNameAnalysis vpga(op); if (!vpga.getVerifierResult()) { return failure(); } From ca6df1be60ac1bc986895f7689fb7326669c1433 Mon Sep 17 00:00:00 2001 From: Romain Moyard Date: Fri, 11 Oct 2024 13:58:19 -0400 Subject: [PATCH 40/45] Update --- mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp b/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp index 7cd45cce99..b69eb991db 100644 --- a/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp +++ b/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp @@ -45,7 +45,7 @@ struct MergeRotationsRewritePattern : public mlir::OpRewritePattern { ValueRange inQubits = op.getInQubits(); auto parentOp = dyn_cast_or_null(inQubits[0].getDefiningOp()); - VerifyParentGateAndNameAnalysis vpga(op); + VerifyParentGateAndNameAnalysis vpga(op); if (!vpga.getVerifierResult()) { return failure(); } @@ -69,7 +69,6 @@ struct MergeRotationsRewritePattern : public mlir::OpRewritePattern { parentInCtrlQubits, parentInCtrlValues); op.replaceAllUsesWith(mergeOp); - op.erase(); parentOp.erase(); return success(); @@ -110,7 +109,6 @@ struct MergeMultiRZRewritePattern : public mlir::OpRewritePattern { sumParam, parentInQubits, nullptr, parentInCtrlQubits, parentInCtrlValues); op.replaceAllUsesWith(mergeOp); - op.erase(); parentOp.erase(); return success(); From bd0ddb6be3404d53ecb48787a74eecf170212222 Mon Sep 17 00:00:00 2001 From: Romain Moyard Date: Fri, 11 Oct 2024 14:55:52 -0400 Subject: [PATCH 41/45] Update --- doc/releases/changelog-dev.md | 37 +++++++++++++++++++ .../Catalyst/Transforms/RegisterAllPasses.cpp | 1 - .../Transforms/MergeRotationsPatterns.cpp | 2 + mlir/test/Quantum/MergeRotationsTest.mlir | 26 +++++++++++++ 4 files changed, 65 insertions(+), 1 deletion(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 614c2ddfc4..d42035bd4d 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -117,6 +117,43 @@ Available MLIR passes are now documented and available within the [catalyst.passes module documentation](https://docs.pennylane.ai/projects/catalyst/en/stable/code/__init__.html#module-catalyst.passes). +* A peephole merge rotations pass is now available in MLIR. It can be added to `catalyst.passes.pipeline`, or the + Python function `catalyst.passes.merge_rotations` can be directly called on a `QNode`. + [(#1162)](https://github.com/PennyLaneAI/catalyst/pull/1162) + + Using the pipeline, one can run: + + ```python + from catalys.passes import pipeline + + my_passes = { + "merge_rotations": {} + } + + @qjit(circuit_transform_pipeline=my_passes) + @qml.qnode(qml.device("lightning.qubit", wires=1)) + def g(x: float): + qml.RX(x, wires=0) + qml.RX(x, wires=0) + qml.Hadamard(wires=0) + return qml.expval(qml.PauliZ(0)) + ``` + + Using the python function, one can run: + + ```python + from catalys.passes import merge_rotations + + @qjit + @merge_rotations + @qml.qnode(qml.device("lightning.qubit", wires=1)) + def g(x: float): + qml.RX(x, wires=0) + qml.RX(x, wires=0) + qml.Hadamard(wires=0) + return qml.expval(qml.PauliZ(0)) + ``` + * Catalyst Autograph now supports updating a single index or a slice of JAX arrays using Python's array assignment operator syntax. [(#769)](https://github.com/PennyLaneAI/catalyst/pull/769) diff --git a/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp b/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp index 6e683f3fe9..d3c347fabf 100644 --- a/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp +++ b/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp @@ -41,7 +41,6 @@ void catalyst::registerAllCatalystPasses() mlir::registerPass(catalyst::createInlineNestedModulePass); mlir::registerPass(catalyst::createMemrefCopyToLinalgCopyPass); mlir::registerPass(catalyst::createMemrefToLLVMWithTBAAPass); - mlir::registerPass(catalyst::createMergeRotationsPass); mlir::registerPass(catalyst::createMitigationLoweringPass); mlir::registerPass(catalyst::createQnodeToAsyncLoweringPass); mlir::registerPass(catalyst::createQuantumBufferizationPass); diff --git a/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp b/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp index b69eb991db..9fab424ae8 100644 --- a/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp +++ b/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp @@ -69,6 +69,7 @@ struct MergeRotationsRewritePattern : public mlir::OpRewritePattern { parentInCtrlQubits, parentInCtrlValues); op.replaceAllUsesWith(mergeOp); + op.erase(); parentOp.erase(); return success(); @@ -109,6 +110,7 @@ struct MergeMultiRZRewritePattern : public mlir::OpRewritePattern { sumParam, parentInQubits, nullptr, parentInCtrlQubits, parentInCtrlValues); op.replaceAllUsesWith(mergeOp); + op.erase(); parentOp.erase(); return success(); diff --git a/mlir/test/Quantum/MergeRotationsTest.mlir b/mlir/test/Quantum/MergeRotationsTest.mlir index c1d55f67be..aaf646eee1 100644 --- a/mlir/test/Quantum/MergeRotationsTest.mlir +++ b/mlir/test/Quantum/MergeRotationsTest.mlir @@ -252,4 +252,30 @@ func.func @test_merge_rotations(%arg0: f64, %arg1: f64, %arg2: f64) -> (!quantum %4:2 = quantum.multirz (%arg0) %1, %2 : !quantum.bit, !quantum.bit %5 = quantum.multirz (%arg1) %4#0 : !quantum.bit return %5, %4#1 : !quantum.bit, !quantum.bit +} + +// ----- + +func.func @test_merge_rotations(%arg0: f64, %arg1: f64, %arg2: f64) -> (!quantum.bit, !quantum.bit, !quantum.bit) { + // CHECK: [[true:%.+]] = llvm.mlir.constant + // CHECK: [[false:%.+]] = llvm.mlir.constant + %true = llvm.mlir.constant (1 : i1) :i1 + %false = llvm.mlir.constant (0 : i1) :i1 + + // CHECK: quantum.alloc + // CHECK: [[qubit0:%.+]] = quantum.extract {{.+}}[ 0] + // CHECK: [[qubit1:%.+]] = quantum.extract {{.+}}[ 1] + // CHECK: [[qubit2:%.+]] = quantum.extract {{.+}}[ 2] + %reg = quantum.alloc( 4) : !quantum.reg + %0 = quantum.extract %reg[ 0] : !quantum.reg -> !quantum.bit + %1 = quantum.extract %reg[ 1] : !quantum.reg -> !quantum.bit + %2 = quantum.extract %reg[ 2] : !quantum.reg -> !quantum.bit + // CHECK: [[angle0:%.+]] = arith.addf %arg0, %arg1 : f64 + // CHECK: [[angle1:%.+]] = arith.addf %arg1, %arg2 : f64 + // CHECK: [[angle2:%.+]] = arith.addf %arg2, %arg0 : f64 + // CHECK: [[ret:%.+]], [[ctrlret:%.+]]:2 = quantum.custom "Rot"([[angle0]], [[angle1]], [[angle2]]) [[qubit0]] ctrls([[qubit1]], [[qubit2]]) ctrlvals([[true]], [[false]]) : !quantum.bit ctrls !quantum.bit, !quantum.bit + %out_qubits, %out_ctrl_qubits:2 = quantum.custom "Rot"(%arg0, %arg1, %arg2) %0 ctrls(%1, %2) ctrlvals(%true, %false) : !quantum.bit ctrls !quantum.bit, !quantum.bit + %out_qubits_1, %out_ctrl_qubits_1:2 = quantum.custom "Rot"(%arg1, %arg2, %arg0) %out_qubits ctrls(%out_ctrl_qubits#0, %out_ctrl_qubits#1) ctrlvals(%true, %false) : !quantum.bit ctrls !quantum.bit, !quantum.bit + + return %out_qubits_1, %out_ctrl_qubits_1#0, %out_ctrl_qubits_1#1 : !quantum.bit, !quantum.bit, !quantum.bit } \ No newline at end of file From 01cd3f5043f5f7e3243aeb0cf4cffa57607dae6c Mon Sep 17 00:00:00 2001 From: Romain Moyard Date: Fri, 11 Oct 2024 16:25:35 -0400 Subject: [PATCH 42/45] Remove erase --- mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp | 5 ----- 1 file changed, 5 deletions(-) diff --git a/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp b/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp index 9fab424ae8..0edd29b2cf 100644 --- a/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp +++ b/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp @@ -69,8 +69,6 @@ struct MergeRotationsRewritePattern : public mlir::OpRewritePattern { parentInCtrlQubits, parentInCtrlValues); op.replaceAllUsesWith(mergeOp); - op.erase(); - parentOp.erase(); return success(); } @@ -110,9 +108,6 @@ struct MergeMultiRZRewritePattern : public mlir::OpRewritePattern { sumParam, parentInQubits, nullptr, parentInCtrlQubits, parentInCtrlValues); op.replaceAllUsesWith(mergeOp); - op.erase(); - parentOp.erase(); - return success(); } }; From def5b3aa60e2728beb86d298b669c2ad9f869d5a Mon Sep 17 00:00:00 2001 From: Romain Moyard Date: Fri, 11 Oct 2024 16:26:53 -0400 Subject: [PATCH 43/45] Update --- mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp b/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp index 0edd29b2cf..93fd2546c8 100644 --- a/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp +++ b/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp @@ -104,9 +104,9 @@ struct MergeMultiRZRewritePattern : public mlir::OpRewritePattern { mlir::Value sumParam = rewriter.create(loc, parentTheta, theta).getResult(); - auto mergeOp = rewriter.create(loc, outQubitsTypes, outQubitsCtrlTypes, - sumParam, parentInQubits, nullptr, - parentInCtrlQubits, parentInCtrlValues); + auto mergeOp = rewriter.create(loc, outQubitsTypes, outQubitsCtrlTypes, sumParam, + parentInQubits, nullptr, parentInCtrlQubits, + parentInCtrlValues); op.replaceAllUsesWith(mergeOp); return success(); } From 0213ae7605523787cd2c5449c7e7faeb6e143df7 Mon Sep 17 00:00:00 2001 From: Romain Moyard Date: Fri, 11 Oct 2024 16:29:03 -0400 Subject: [PATCH 44/45] Pylint --- frontend/test/pytest/test_peephole_optimizations.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/frontend/test/pytest/test_peephole_optimizations.py b/frontend/test/pytest/test_peephole_optimizations.py index 6b552edfc0..6f5a7a1077 100644 --- a/frontend/test/pytest/test_peephole_optimizations.py +++ b/frontend/test/pytest/test_peephole_optimizations.py @@ -64,7 +64,7 @@ def reference(x): @pytest.mark.parametrize("theta", [42.42]) -def test_cancel_inverses_functionality(theta, backend): +def test_merge_rotation_functionality(theta, backend): @qjit def workflow(): @@ -128,7 +128,7 @@ def g(x): @pytest.mark.parametrize("theta", [42.42]) -def test_pipeline_functionality(capfd, theta, backend): +def test_pipeline_functionality(theta, backend): """ Test that the @pipeline decorator does not change functionality when all the passes in the pipeline does not change functionality. From 39c8a7ae47f3b203d4feed85a5f2942b6ca7ebd2 Mon Sep 17 00:00:00 2001 From: Romain Moyard Date: Fri, 11 Oct 2024 17:30:36 -0400 Subject: [PATCH 45/45] Draft --- mlir/include/Quantum/IR/QuantumOps.td | 1 + mlir/lib/Quantum/IR/QuantumOps.cpp | 30 +++++++++++++++++++ .../lib/Quantum/Transforms/merge_rotation.cpp | 1 + 3 files changed, 32 insertions(+) diff --git a/mlir/include/Quantum/IR/QuantumOps.td b/mlir/include/Quantum/IR/QuantumOps.td index cf8532fdfc..904d865e5a 100644 --- a/mlir/include/Quantum/IR/QuantumOps.td +++ b/mlir/include/Quantum/IR/QuantumOps.td @@ -384,6 +384,7 @@ def CustomOp : UnitaryGate_Op<"custom", [DifferentiableGate, NoMemoryEffect, return getParams(); } }]; + let hasCanonicalizeMethod = 1; } def GlobalPhaseOp : UnitaryGate_Op<"gphase", [DifferentiableGate, AttrSizedOperandSegments]> { diff --git a/mlir/lib/Quantum/IR/QuantumOps.cpp b/mlir/lib/Quantum/IR/QuantumOps.cpp index 9b86f63d7b..73cc3a5877 100644 --- a/mlir/lib/Quantum/IR/QuantumOps.cpp +++ b/mlir/lib/Quantum/IR/QuantumOps.cpp @@ -12,8 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/Builders.h" #include "mlir/IR/OpImplementation.h" +#include "llvm/ADT/StringSet.h" #include "llvm/ADT/TypeSwitch.h" #include @@ -34,6 +36,34 @@ using namespace catalyst::quantum; //===----------------------------------------------------------------------===// // Quantum op canonicalizers. //===----------------------------------------------------------------------===// +static const mlir::StringSet<> hermitianOps = {"Hadamard", "PauliX", "PauliY", "PauliZ", "CNOT", + "CY", "CZ", "SWAP", "Toffoli"}; +static const mlir::StringSet<> rotationsOps = {"RX", "RY", "RZ", "PhaseShift", "Rot", + "CRX", "CRY", "CRZ", "ControlledPhaseShift", "CRot"}; +LogicalResult CustomOp::canonicalize(CustomOp op, mlir::PatternRewriter &rewriter) +{ + if (op.getAdjoint()) { + auto name = op.getGateName(); + if (hermitianOps.contains(name)) { + op.setAdjoint(false); + return success(); + } + else if (rotationsOps.contains(name)) { + auto params = op.getParams(); + SmallVector paramsMinus; + for (auto param : params) { + rewriter.create(op.getLoc(), param); + } + auto adjointOp = rewriter.create(op.getLoc(), op.getOutQubits().getTypes(), op.getOutCtrlQubits().getTypes(), + paramsMinus, op.getInQubits(), name, nullptr, + op.getInCtrlQubits(), op.getInCtrlValues()); + rewriter.replaceOp(op, adjointOp.getResults()); + return success(); + } + return failure(); + }; + return failure(); +} LogicalResult DeallocOp::canonicalize(DeallocOp dealloc, mlir::PatternRewriter &rewriter) { diff --git a/mlir/lib/Quantum/Transforms/merge_rotation.cpp b/mlir/lib/Quantum/Transforms/merge_rotation.cpp index 45819b0247..c19f62a2d0 100644 --- a/mlir/lib/Quantum/Transforms/merge_rotation.cpp +++ b/mlir/lib/Quantum/Transforms/merge_rotation.cpp @@ -62,6 +62,7 @@ struct MergeRotationsPass : impl::MergeRotationsPassBase { } RewritePatternSet patterns(&getContext()); + catalyst::quantum::CustomOp::getCanonicalizationPatterns(patterns, &getContext()); populateMergeRotationsPatterns(patterns); if (failed(applyPatternsAndFoldGreedily(targetfunc, std::move(patterns)))) {