Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] New compiler driver interface #1208

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/build-wheel-linux-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ jobs:
-DLLVM_ENABLE_ZSTD=FORCE_ON \
-DLLVM_ENABLE_LLD=ON

cmake --build quantum-build --target check-dialects compiler_driver
cmake --build quantum-build --target check-dialects compiler_driver qcc

- name: Build wheel
run: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/build-wheel-macos-arm64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ jobs:
-DLLVM_ENABLE_LLD=OFF \
-DLLVM_DIR=$GITHUB_WORKSPACE/llvm-build/lib/cmake/llvm

cmake --build quantum-build --target check-dialects compiler_driver
cmake --build quantum-build --target check-dialects compiler_driver qcc

- name: Build wheel
run: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/build-wheel-macos-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ jobs:
-DLLVM_ENABLE_ZSTD=FORCE_ON \
-DLLVM_ENABLE_LLD=OFF

cmake --build quantum-build --target check-dialects compiler_driver
cmake --build quantum-build --target check-dialects compiler_driver qcc

- name: Build wheel
run: |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ cmake -S mlir -B quantum-build -G Ninja \
-DLLVM_ENABLE_ZSTD=FORCE_ON \
-DLLVM_ENABLE_LLD=ON \
-DLLVM_DIR=/catalyst/llvm-build/lib/cmake/llvm
cmake --build quantum-build --target check-dialects compiler_driver
cmake --build quantum-build --target check-dialects compiler_driver qcc

# Copy files needed for the wheel where they are expected
cp /catalyst/runtime-build/lib/*/*/*/*/librtd* /catalyst/runtime-build/lib
Expand Down
43 changes: 35 additions & 8 deletions frontend/catalyst/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,6 @@ class Compiler:
@debug_logger_init
def __init__(self, options: Optional[CompileOptions] = None):
self.options = options if options is not None else CompileOptions()
self.last_compiler_output = None

@debug_logger
def run_from_ir(self, ir: str, module_name: str, workspace: Directory):
Expand Down Expand Up @@ -601,7 +600,6 @@ def run_from_ir(self, ir: str, module_name: str, workspace: Directory):
else:
output_filename = filename

self.last_compiler_output = compiler_output
return output_filename, out_IR

@debug_logger
Expand Down Expand Up @@ -630,20 +628,49 @@ def run(self, mlir_module, *args, **kwargs):
)

@debug_logger
def get_output_of(self, pipeline) -> Optional[str]:
def get_output_of(self, pipeline, workspace) -> Optional[str]:
"""Get the output IR of a pipeline.
Args:
pipeline (str): name of pass class

Returns
(Optional[str]): output IR
"""
if not self.last_compiler_output or not self.last_compiler_output.get_pipeline_output(
pipeline
):
file_content = None
for dirpath, _, filenames in os.walk(str(workspace)):
filenames = [f for f in filenames if f.endswith(".mlir") or f.endswith(".ll")]
if not filenames:
break
filenames_no_ext = [os.path.splitext(f)[0] for f in filenames]
if pipeline == "mlir":
# Sort files and pick the first one
selected_file = [
sorted(filenames)[0],
]
elif pipeline == "last":
# Sort files and pick the last one
selected_file = [
sorted(filenames)[-1],
]
else:
selected_file = [
f
for f, name_no_ext in zip(filenames, filenames_no_ext)
if pipeline in name_no_ext
]
if len(selected_file) != 1:
msg = f"Attempting to get output for pipeline: {pipeline},"
msg += " but no or more than one file was found.\n"
raise CompileError(msg)
filename = selected_file[0]

full_path = os.path.join(dirpath, filename)
with open(full_path, "r", encoding="utf-8") as file:
file_content = file.read()

if file_content is None:
msg = f"Attempting to get output for pipeline: {pipeline},"
msg += " but no file was found.\n"
msg += "Are you sure the file exists?"
raise CompileError(msg)

return self.last_compiler_output.get_pipeline_output(pipeline)
return file_content
4 changes: 1 addition & 3 deletions frontend/catalyst/debug/compiler_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,7 @@ def func(x: float):
if not isinstance(fn, catalyst.QJIT):
raise TypeError(f"First argument needs to be a 'QJIT' object, got a {type(fn)}.")

if stage == "last":
return fn.compiler.last_compiler_output.get_output_ir()
return fn.compiler.get_output_of(stage)
return fn.compiler.get_output_of(stage, fn.workspace)


@debug_logger
Expand Down
23 changes: 11 additions & 12 deletions frontend/test/pytest/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def test_attempts_to_get_inexistent_intermediate_file(self):
"""Test the return value if a user requests an intermediate file that doesn't exist."""
compiler = Compiler()
with pytest.raises(CompileError, match="Attempting to get output for pipeline"):
compiler.get_output_of("inexistent-file")
compiler.get_output_of("inexistent-file", ".")

def test_runtime_error(self, backend):
"""Test with non-default flags."""
Expand Down Expand Up @@ -222,15 +222,15 @@ def workflow():

compiler = workflow.compiler
with pytest.raises(CompileError, match="Attempting to get output for pipeline"):
compiler.get_output_of("EmptyPipeline1")
assert compiler.get_output_of("HLOLoweringPass")
assert compiler.get_output_of("QuantumCompilationPass")
compiler.get_output_of("EmptyPipeline1", workflow.workspace)
assert compiler.get_output_of("HLOLoweringPass", workflow.workspace)
assert compiler.get_output_of("QuantumCompilationPass", workflow.workspace)
with pytest.raises(CompileError, match="Attempting to get output for pipeline"):
compiler.get_output_of("EmptyPipeline2")
assert compiler.get_output_of("BufferizationPass")
assert compiler.get_output_of("MLIRToLLVMDialect")
compiler.get_output_of("EmptyPipeline2", workflow.workspace)
assert compiler.get_output_of("BufferizationPass", workflow.workspace)
assert compiler.get_output_of("MLIRToLLVMDialect", workflow.workspace)
with pytest.raises(CompileError, match="Attempting to get output for pipeline"):
compiler.get_output_of("None-existing-pipeline")
compiler.get_output_of("None-existing-pipeline", workflow.workspace)
workflow.workspace.cleanup()

def test_print_nonexistent_stages(self, backend):
Expand All @@ -243,7 +243,7 @@ def workflow():
return qml.state()

with pytest.raises(CompileError, match="Attempting to get output for pipeline"):
workflow.compiler.get_output_of("None-existing-pipeline")
workflow.compiler.get_output_of("None-existing-pipeline", workflow.workspace)
workflow.workspace.cleanup()

def test_workspace(self):
Expand Down Expand Up @@ -305,10 +305,9 @@ def circuit():
compiled.compile()

assert "Failed to lower MLIR module" in e.value.args[0]
assert "While processing 'TestPass' pass of the 'PipelineB' pipeline" in e.value.args[0]
assert "PipelineA" not in e.value.args[0]
assert "While processing 'TestPass' pass " in e.value.args[0]
assert "Trace" not in e.value.args[0]
assert isfile(os.path.join(str(compiled.workspace), "2_PipelineB_FAILED.mlir"))
assert isfile(os.path.join(str(compiled.workspace), "2_TestPass_FAILED.mlir"))
compiled.workspace.cleanup()

with pytest.raises(CompileError) as e:
Expand Down
6 changes: 1 addition & 5 deletions frontend/test/pytest/test_debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,8 +376,6 @@ def f(x):
"""Square function."""
return x**2

f.__name__ = f.__name__ + pass_name

jit_f = qjit(f, keep_intermediate=True)
data = 2.0
old_result = jit_f(data)
Expand All @@ -400,8 +398,6 @@ def f(x: float):
"""Square function."""
return x**2

f.__name__ = f.__name__ + pass_name

jit_f = qjit(f)
jit_grad_f = qjit(value_and_grad(jit_f), keep_intermediate=True)
jit_grad_f(3.0)
Expand All @@ -418,7 +414,7 @@ def f(x: float):
assert len(res) == 0

def test_get_compilation_stage_without_keep_intermediate(self):
"""Test if error is raised when using get_pipeline_output without keep_intermediate."""
"""Test if error is raised when using get_compilation_stage without keep_intermediate."""

@qjit
def f(x: float):
Expand Down
2 changes: 1 addition & 1 deletion mlir/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ dialects:
-DLLVM_ENABLE_ZLIB=$(ENABLE_ZLIB) \
-DLLVM_ENABLE_ZSTD=$(ENABLE_ZSTD)

cmake --build $(DIALECTS_BUILD_DIR) --target check-dialects quantum-lsp-server compiler_driver
cmake --build $(DIALECTS_BUILD_DIR) --target check-dialects quantum-lsp-server compiler_driver qcc

.PHONY: test
test:
Expand Down
38 changes: 25 additions & 13 deletions mlir/include/Driver/CompilerDriver.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,12 @@

#include "mlir/IR/MLIRContext.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Tools/mlir-opt/MlirOptMain.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"

#include "Driver/Pipelines.h"

namespace catalyst {
namespace driver {

Expand All @@ -32,6 +35,12 @@ namespace driver {
// low-level messages, we might want to hide these.
enum class Verbosity { Silent = 0, Urgent = 1, Debug = 2, All = 3 };

enum SaveTemps { None, AfterPipeline, AfterPass };

enum Action { OPT, Translate, LLC, All };

enum InputType { MLIR, LLVMIR, OTHER };

/// Helper verbose reporting macro.
#define CO_MSG(opt, level, op) \
do { \
Expand All @@ -40,14 +49,6 @@ enum class Verbosity { Silent = 0, Urgent = 1, Debug = 2, All = 3 };
} \
} while (0)

/// Pipeline descriptor
struct Pipeline {
using Name = std::string;
using PassList = llvm::SmallVector<std::string>;
Name name;
PassList passes;
};

/// Optional parameters, for which we provide reasonable default values.
struct CompilerOptions {
/// The textual IR (MLIR or LLVM IR)
Expand All @@ -58,19 +59,21 @@ struct CompilerOptions {
mlir::StringRef moduleName;
/// The stream to output any error messages from MLIR/LLVM passes and translation.
llvm::raw_ostream &diagnosticStream;
/// If true, the driver will output the module at intermediate points.
bool keepIntermediate;
/// If specified, the driver will output the module after each pipeline or each pass.
SaveTemps keepIntermediate;
/// If true, the llvm.coroutine will be lowered.
bool asyncQnodes;
/// Sets the verbosity level to use when printing messages.
Verbosity verbosity;
/// Ordered list of named pipelines to execute, each pipeline is described by a list of MLIR
/// passes it includes.
std::vector<Pipeline> pipelinesCfg;
/// Whether to assume that the pipelines output is a valid LLVM dialect and lower it to LLVM IR
bool lowerToLLVM;
/// Specify that the compiler should start after reaching the given pass.
std::string checkpointStage;
/// Specify the loweting action to perform
Action loweringAction;
/// If true, the compiler will dump the pass pipeline that will be run.
bool dumpPassPipeline;

/// Get the destination of the object file at the end of compilation.
std::string getObjectFile() const
Expand Down Expand Up @@ -103,7 +106,16 @@ struct CompilerOutput {

/// Entry point to the MLIR portion of the compiler.
mlir::LogicalResult QuantumDriverMain(const catalyst::driver::CompilerOptions &options,
catalyst::driver::CompilerOutput &output);
catalyst::driver::CompilerOutput &output,
mlir::DialectRegistry &registry);

int QuantumDriverMainFromCL(int argc, char **argv);
int QuantumDriverMainFromArgs(const std::string &source, const std::string &workspace,
const std::string &moduleName, bool keepIntermediate,
bool asyncQNodes, bool verbose, bool lowerToLLVM,
const std::vector<catalyst::driver::Pipeline> &passPipelines,
const std::string &checkpointStage,
catalyst::driver::CompilerOutput &output);

namespace llvm {

Expand Down
61 changes: 61 additions & 0 deletions mlir/include/Driver/Pipelines.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// Copyright 2024 Xanadu Quantum Technologies Inc.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "mlir/Pass/Pass.h"

namespace catalyst {
namespace driver {

void createEnforceRuntimeInvariantsPipeline(mlir::OpPassManager &pm);
void createHloLoweringPipeline(mlir::OpPassManager &pm);
void createQuantumCompilationPipeline(mlir::OpPassManager &pm);
void createBufferizationPipeline(mlir::OpPassManager &pm);
void createLLVMDialectLoweringPipeline(mlir::OpPassManager &pm);
void createDefaultCatalystPipeline(mlir::OpPassManager &pm);

void registerEnforceRuntimeInvariantsPipeline();
void registerHloLoweringPipeline();
void registerQuantumCompilationPipeline();
void registerBufferizationPipeline();
void registerLLVMDialectLoweringPipeline();
void registerDefaultCatalystPipeline();
void registerAllCatalystPipelines();

/// Pipeline descriptor
struct Pipeline {
using Name = std::string;
using PassList = llvm::SmallVector<std::string>;
using PipelineFunc = void (*)(mlir::OpPassManager &);
Name name;
PassList passes;
PipelineFunc registerFunc = nullptr;

mlir::LogicalResult addPipeline(mlir::OpPassManager &pm)
{
if (registerFunc) {
registerFunc(pm);
return mlir::success();
}
else {
return mlir::failure();
}
}
};

std::vector<Pipeline> getDefaultPipeline();

} // namespace driver
} // namespace catalyst
3 changes: 3 additions & 0 deletions mlir/lib/Driver/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@ set(LLVM_LINK_COMPONENTS
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS)
get_property(translation_libs GLOBAL PROPERTY MLIR_TRANSLATION_LIBS)
set(LIBS
${dialect_libs}
${conversion_libs}
${extension_libs}
${translation_libs}
MLIROptLib
MLIRCatalyst
MLIRCatalystTransforms
Expand All @@ -42,6 +44,7 @@ set(LIBS
add_mlir_library(CatalystCompilerDriver
CompilerDriver.cpp
CatalystLLVMTarget.cpp
Pipelines.cpp

LINK_LIBS PRIVATE
${EXTERNAL_LIB}
Expand Down
Loading
Loading