Skip to content

Commit

Permalink
[torch-frontend] fix build.sh when building from llvm source (#475)
Browse files Browse the repository at this point in the history
Resolve #459
  • Loading branch information
qingyunqu authored Oct 26, 2024
1 parent 33a69e8 commit 43cd8bc
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 6 deletions.
2 changes: 1 addition & 1 deletion frontends/torch-frontend/MLIR.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ if (NOT DEFINED MLIR_DIR OR "${MLIR_DIR}" STREQUAL "MLIR_DIR-NOTFOUND")
-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-DCMAKE_BUILD_TYPE=Release
-DLLVM_ENABLE_PROJECTS=mlir
-DLLVM_TARGETS_TO_BUILD=X86\;NVPTX\;AMDGPU
-DLLVM_TARGETS_TO_BUILD=X86
-DLLVM_ENABLE_ASSERTIONS=ON
-DMLIR_ENABLE_BINDINGS_PYTHON=ON

Expand Down
2 changes: 1 addition & 1 deletion frontends/torch-frontend/TorchMLIR.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ function(build_torch_mlir)
endif()

execute_process(
COMMAND ${CMAKE_COMMAND} --build ${TORCH_MLIR_BUILD_PATH} --target TorchMLIRPythonModules
COMMAND ${CMAKE_COMMAND} --build ${TORCH_MLIR_BUILD_PATH} --target TorchMLIRPythonModules TorchMLIRJITIRImporterPybind
RESULT_VARIABLE result
WORKING_DIRECTORY ${TORCH_MLIR_SRC_PATH}
)
Expand Down
8 changes: 4 additions & 4 deletions frontends/torch-frontend/scripts/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,18 @@ cmake -S . \
-GNinja \
-DLLVM_EXTERNAL_LIT=$(which lit) \
-DCMAKE_BUILD_TYPE=Release \
-DCMAKE_C_COMPILER=gcc \
-DCMAKE_CXX_COMPILER=g++ \
-DCMAKE_C_COMPILER=clang-11 \
-DCMAKE_CXX_COMPILER=clang++-11 \
-DTORCH_FRONTEND_ENABLE_JIT_IR_IMPORTER=${TORCH_FRONTEND_ENABLE_JIT_IR_IMPORTER} \
-DCMAKE_CXX_FLAGS="-Wno-unused-but-set-parameter -Wno-unused-but-set-variable" \
-DCMAKE_CXX_FLAGS="-fPIC" \
-DPython3_EXECUTABLE=$(which python3)

cmake --build ./build --target all

if [[ $TORCH_FRONTEND_TEST == "ON" ]]; then
python3 -m pip install -r test-requirements.txt
install_mhlo_tools
PYTHONPATH=build/python_packages/:build/torch_mlir_build/python_packages/torch_mlir python3 -m pytest torch-frontend/python/test
PYTHONPATH=build/python_packages/:build/torch_mlir_build/python_packages/torch_mlir TORCH_DISABLE_NATIVE_FUNCOL=1 python3 -m pytest -m "not attention_rewriter" torch-frontend/python/test
fi

popd
10 changes: 10 additions & 0 deletions frontends/torch-frontend/scripts/envsetup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,15 @@ function apply_patches() {
popd
}

function apply_llvm_patches() {
pushd $TORCH_MLIR_ROOT/externals/llvm-project
git clean -fd .
for patch in ../../../llvm_patches/*; do
git apply $patch
done
popd
}

function prepare_for_build_with_prebuilt() {
pushd ${PROJ_DIR}
# install requirements
Expand All @@ -53,6 +62,7 @@ function prepare_for_build() {
git submodule update --init --recursive -f $TORCH_MLIR_ROOT

apply_patches
apply_llvm_patches
}

# python3 -m pip install --no-cache-dir torch==2.1.0+cu118 torchvision==0.16.0+cu118 -f https://download.pytorch.org/whl/torch_stable.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h
index 35db138305d1..6a59b1642d30 100644
--- a/mlir/include/mlir-c/Pass.h
+++ b/mlir/include/mlir-c/Pass.h
@@ -75,8 +75,10 @@ MLIR_CAPI_EXPORTED MlirLogicalResult
mlirPassManagerRunOnOp(MlirPassManager passManager, MlirOperation op);

/// Enable mlir-print-ir-after-all.
-MLIR_CAPI_EXPORTED void
-mlirPassManagerEnableIRPrinting(MlirPassManager passManager);
+MLIR_CAPI_EXPORTED void mlirPassManagerEnableIRPrinting(
+ MlirPassManager passManager, bool printBeforePass, bool printAfterPass,
+ bool printModuleScope, bool printAfterOnlyOnChange,
+ bool printAfterOnlyOnFailure, MlirOpPrintingFlags flags);

/// Enable / disable verify-each.
MLIR_CAPI_EXPORTED void
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index cdbfcfbc2295..e9a3780eb772 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -73,9 +73,34 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
"Releases (leaks) the backing pass manager (testing)")
.def(
"enable_ir_printing",
- [](PyPassManager &passManager) {
- mlirPassManagerEnableIRPrinting(passManager.get());
+ [](PyPassManager &passManager, bool printBeforePass,
+ bool printAfterPass, bool printModuleScope,
+ bool printAfterOnlyOnChange, bool printAfterOnlyOnFailure,
+ std::optional<int64_t> largeElementsLimit, bool enableDebugInfo,
+ bool printGenericOpForm) {
+ MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
+ if (largeElementsLimit)
+ mlirOpPrintingFlagsElideLargeElementsAttrs(flags,
+ *largeElementsLimit);
+ if (enableDebugInfo)
+ mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true,
+ /*prettyForm=*/false);
+ if (printGenericOpForm)
+ mlirOpPrintingFlagsPrintGenericOpForm(flags);
+ mlirPassManagerEnableIRPrinting(passManager.get(), printBeforePass,
+ printAfterPass, printModuleScope,
+ printAfterOnlyOnChange,
+ printAfterOnlyOnFailure, flags);
+ mlirOpPrintingFlagsDestroy(flags);
},
+ py::arg("print_before_pass") = true,
+ py::arg("print_after_pass") = true,
+ py::arg("print_module_scope") = true,
+ py::arg("print_after_only_on_change") = true,
+ py::arg("print_after_only_on_failure") = false,
+ py::arg("large_elements_limit") = py::none(),
+ py::arg("enable_debug_info") = false,
+ py::arg("print_generic_op_form") = false,
"Enable mlir-print-ir-after-all.")
.def(
"enable_verifier",
diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp
index d242baae99c0..d13a71bb19cf 100644
--- a/mlir/lib/CAPI/IR/Pass.cpp
+++ b/mlir/lib/CAPI/IR/Pass.cpp
@@ -13,6 +13,7 @@
#include "mlir/CAPI/Support.h"
#include "mlir/CAPI/Utils.h"
#include "mlir/Pass/PassManager.h"
+#include <functional>
#include <optional>

using namespace mlir;
@@ -44,8 +45,23 @@ MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager,
return wrap(unwrap(passManager)->run(unwrap(op)));
}

-void mlirPassManagerEnableIRPrinting(MlirPassManager passManager) {
- return unwrap(passManager)->enableIRPrinting();
+void mlirPassManagerEnableIRPrinting(MlirPassManager passManager,
+ bool printBeforePass, bool printAfterPass,
+ bool printModuleScope,
+ bool printAfterOnlyOnChange,
+ bool printAfterOnlyOnFailure,
+ MlirOpPrintingFlags flags) {
+ std::function<bool(Pass *, Operation *)> shouldPrintBeforePass = nullptr;
+ std::function<bool(Pass *, Operation *)> shouldPrintAfterPass = nullptr;
+ if (printBeforePass)
+ shouldPrintBeforePass = [](Pass *, Operation *) { return true; };
+ if (printAfterPass)
+ shouldPrintAfterPass = [](Pass *, Operation *) { return true; };
+ return unwrap(passManager)
+ ->enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass,
+ printModuleScope, printAfterOnlyOnChange,
+ printAfterOnlyOnFailure, /*out=*/llvm::errs(),
+ *unwrap(flags));
}

void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) {
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,5 @@ def test_rewrite_entry_func_name():
print(module_str)
assert "func.func @main" in module_str

if __name__ == "__main__":
test_debug()

0 comments on commit 43cd8bc

Please sign in to comment.