Skip to content

Commit

Permalink
[samples] Add big_matvec sample with pre-existing lowering config (#126)
Browse files Browse the repository at this point in the history

Used to be part of an experiment but likely also useful as a sample for
others.
  • Loading branch information
zero9178 authored Aug 30, 2024
1 parent ec3b531 commit 7466aa6
Show file tree
Hide file tree
Showing 8 changed files with 352 additions and 21 deletions.
7 changes: 5 additions & 2 deletions codegen/compiler/src/Quidditch/Target/ConfigureForSnitch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,11 @@ void ConfigureForSnitch::runOnOperation() {
if (failed(setTranslationInfo(funcOp)))
return signalPassFailure();

if (failed(setRootConfig(funcOp, rootOperation)))
return signalPassFailure();
auto loweringConfig =
getLoweringConfig<quidditch::Snitch::LoweringConfigAttr>(rootOperation);
if (!loweringConfig)
if (failed(setRootConfig(funcOp, rootOperation)))
return signalPassFailure();

// The root configuration setting introduces `tensor.dim` operations.
// Resolve those away.
Expand Down
7 changes: 7 additions & 0 deletions codegen/compiler/src/Quidditch/Target/QuidditchTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,14 @@ class QuidditchTargetBackend final : public IREE::HAL::TargetBackend {
class QuidditchSession final
: public PluginSession<QuidditchSession, QuidditchTargetOptions,
PluginActivationPolicy::DefaultActivated> {
public:
static void registerGlobalDialects(DialectRegistry &registry) {
// Required to allow the 'quidditch_snitch' dialect to also be used in
// input IR without just being parsed as an 'OpaqueAttr'.
registry.insert<quidditch::Snitch::QuidditchSnitchDialect>();
}

private:
void populateHALTargetDevices(IREE::HAL::TargetDeviceList &targets) override {
targets.add("quidditch_device",
[]() { return std::make_shared<QuidditchTargetDevice>(); });
Expand Down
40 changes: 21 additions & 19 deletions runtime/cmake/quidditch_module.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -52,18 +52,14 @@ find_program(XDSL_OPT_PATH xdsl-opt
# The resulting library is the source file's name with the extension removed and
# '_module' appended.
function(quidditch_module)
cmake_parse_arguments(_RULE "LLVM;ASSERT_XDSL" "SRC;N_THREADS;DST" "FLAGS;DEPENDS" ${ARGN})
cmake_parse_arguments(_RULE "LLVM;ASSERT_XDSL" "SRC;DST" "FLAGS;DEPENDS" ${ARGN})

set(_MLIR_SRC "${_RULE_SRC}")
if (NOT _RULE_DST)
cmake_path(GET _MLIR_SRC STEM _RULE_DST)
set(_RULE_DST "${_RULE_DST}")
endif ()

if (NOT _RULE_N_THREADS)
set(_RULE_N_THREADS 8)
endif ()

get_filename_component(_MLIR_SRC "${_MLIR_SRC}" REALPATH)
set(_O_QUIDDITCH_FILE_NAME "${CMAKE_CURRENT_BINARY_DIR}/${_RULE_DST}/${_RULE_DST}.o")
set(_O_LLVM_FILE_NAME "${CMAKE_CURRENT_BINARY_DIR}/${_RULE_DST}/${_RULE_DST}_llvm.o")
Expand Down Expand Up @@ -95,7 +91,9 @@ function(quidditch_module)

list(APPEND _OUTPUT_FILES "${_O_QUIDDITCH_FILE_NAME}")
list(APPEND _OBJECT_FILES "${_O_QUIDDITCH_FILE_NAME}")
list(APPEND _OBJECT_FILES "${_O_LLVM_FILE_NAME}")
if (NOT _RULE_ASSERT_XDSL)
list(APPEND _OBJECT_FILES "${_O_LLVM_FILE_NAME}")
endif ()

string(REPLACE ".o" ".h" _STATIC_HDR_PATH "${_O_QUIDDITCH_FILE_NAME}")
list(APPEND _OUTPUT_FILES "${_STATIC_HDR_PATH}")
Expand All @@ -105,19 +103,23 @@ function(quidditch_module)
endif ()

string(REPLACE ".o" ".h" _STATIC_HDR_PATH "${_O_LLVM_FILE_NAME}")
list(APPEND _OUTPUT_FILES "${_STATIC_HDR_PATH}" "${_O_LLVM_FILE_NAME}")

list(APPEND _COMPILER_ARGS "--iree-hal-target-backends=llvm-cpu")
list(APPEND _COMPILER_ARGS "--iree-llvmcpu-debug-symbols=true")
list(APPEND _COMPILER_ARGS "--iree-llvmcpu-target-triple=riscv32-unknown-elf")
list(APPEND _COMPILER_ARGS "--iree-llvmcpu-target-cpu=generic-rv32")
list(APPEND _COMPILER_ARGS "--iree-llvmcpu-target-cpu-features=+m,+f,+d,+zfh")
list(APPEND _COMPILER_ARGS "--iree-llvmcpu-target-abi=ilp32d")
list(APPEND _COMPILER_ARGS "--iree-llvmcpu-target-float-abi=hard")
list(APPEND _COMPILER_ARGS "--iree-llvmcpu-link-embedded=false")
list(APPEND _COMPILER_ARGS "--iree-llvmcpu-link-static")
list(APPEND _COMPILER_ARGS "--iree-llvmcpu-number-of-threads=${_RULE_N_THREADS}")
list(APPEND _COMPILER_ARGS "--iree-llvmcpu-static-library-output-path=${_O_LLVM_FILE_NAME}")
list(APPEND _OUTPUT_FILES "${_STATIC_HDR_PATH}")

if (NOT _RULE_ASSERT_XDSL)
list(APPEND _COMPILER_ARGS "--iree-hal-target-backends=llvm-cpu")
list(APPEND _COMPILER_ARGS "--iree-llvmcpu-debug-symbols=true")
list(APPEND _COMPILER_ARGS "--iree-llvmcpu-target-triple=riscv32-unknown-elf")
list(APPEND _COMPILER_ARGS "--iree-llvmcpu-target-cpu=generic-rv32")
list(APPEND _COMPILER_ARGS "--iree-llvmcpu-target-cpu-features=+m,+f,+d,+zfh")
list(APPEND _COMPILER_ARGS "--iree-llvmcpu-target-abi=ilp32d")
list(APPEND _COMPILER_ARGS "--iree-llvmcpu-target-float-abi=hard")
list(APPEND _COMPILER_ARGS "--iree-llvmcpu-link-embedded=false")
list(APPEND _COMPILER_ARGS "--iree-llvmcpu-link-static")
list(APPEND _COMPILER_ARGS "--iree-llvmcpu-number-of-threads=8")
list(APPEND _COMPILER_ARGS "--iree-llvmcpu-static-library-output-path=${_O_LLVM_FILE_NAME}")

list(APPEND _OUTPUT_FILES "${_O_LLVM_FILE_NAME}")
endif ()

list(APPEND _COMPILER_ARGS "--output-format=vm-c")
list(APPEND _COMPILER_ARGS "--iree-vm-target-index-bits=32")
Expand Down
1 change: 1 addition & 0 deletions runtime/samples/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Make the quidditch_module() helper available to the sample subdirectories
# below; it compiles an .mlir source file into a linkable object library.
include(quidditch_module)

# One subdirectory per runtime sample (plus shared utilities in 'util').
add_subdirectory(big_matvec)
add_subdirectory(nsnet2)
add_subdirectory(util)
add_subdirectory(vec_multiply)
12 changes: 12 additions & 0 deletions runtime/samples/big_matvec/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Compile big_matvec.mlir through the Quidditch toolchain. ASSERT_XDSL
# requires the xDSL path to handle every kernel: per quidditch_module.cmake it
# drops the LLVM-CPU fallback object and its compiler flags entirely --
# confirm against the helper if that behavior changes.
quidditch_module(SRC big_matvec.mlir ASSERT_XDSL)

# Host-side sample executable driving the compiled module.
add_executable(big_matvec_sample main.c)
target_link_libraries(
big_matvec_sample
PRIVATE
samples_util
big_matvec
snRuntime
Quidditch::dispatch::dispatch
)

126 changes: 126 additions & 0 deletions runtime/samples/big_matvec/big_matvec.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
// Sample kernels: a matrix-vector style product (1x400 times the transpose of
// 320x400, f64, yielding 1x320) followed by a constant bias add. Each function
// carries a pre-existing #quidditch_snitch.lowering_config so the compiler
// keeps this tiling instead of deriving a root configuration itself. The five
// variants differ ONLY in the l1_tiles values; the computation is identical.
builtin.module @big_matvec {
// l1_tiles = [0, 32, 80]: M untiled, N tiled by 32, K tiled by 80.
// l1_tiles_interchange = [2, 0, 1] presumably iterates the K tiles outermost,
// and dual_buffer enables double-buffered copies -- confirm the attribute
// semantics against the quidditch_snitch dialect definition.
func.func @test32(%arg0: tensor<1x400xf64>, %arg1: tensor<320x400xf64>) -> tensor<1x320xf64> {
%init = tensor.empty() : tensor<1x320xf64>
%out = linalg.matmul_transpose_b {
lowering_config = #quidditch_snitch.lowering_config<
l1_tiles = [0, 32, 80],
l1_tiles_interchange = [2, 0, 1],
dual_buffer = true
>
}
ins(%arg0, %arg1 : tensor<1x400xf64>, tensor<320x400xf64>)
outs(%init : tensor<1x320xf64>) -> tensor<1x320xf64>
// Elementwise bias add: out2[i, j] = out[i, j] + 5.0.
// NOTE(review): %out is passed as both ins and outs; the %outs block
// argument is ignored, so this relies on bufferization resolving the
// alias -- confirm this is intended.
%out2 = linalg.generic {
indexing_maps = [
affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>
], iterator_types = ["parallel", "parallel"]
} ins(%out : tensor<1x320xf64>) outs(%out : tensor<1x320xf64>) {
^bb0(%element : f64, %outs : f64):
%bias = arith.constant 5.0 : f64
%added = arith.addf %element, %bias : f64
linalg.yield %added : f64
} -> tensor<1x320xf64>
func.return %out2 : tensor<1x320xf64>
}

// Same computation as @test32; only the N/K tiles differ: l1_tiles = [0, 40, 80].
func.func @test40(%arg0: tensor<1x400xf64>, %arg1: tensor<320x400xf64>) -> tensor<1x320xf64> {
%init = tensor.empty() : tensor<1x320xf64>
%out = linalg.matmul_transpose_b {
lowering_config = #quidditch_snitch.lowering_config<
l1_tiles = [0, 40, 80],
l1_tiles_interchange = [2, 0, 1],
dual_buffer = true
>
}
ins(%arg0, %arg1 : tensor<1x400xf64>, tensor<320x400xf64>)
outs(%init : tensor<1x320xf64>) -> tensor<1x320xf64>
// Elementwise +5.0 bias add (same ins/outs aliasing note as @test32).
%out2 = linalg.generic {
indexing_maps = [
affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>
], iterator_types = ["parallel", "parallel"]
} ins(%out : tensor<1x320xf64>) outs(%out : tensor<1x320xf64>) {
^bb0(%element : f64, %outs : f64):
%bias = arith.constant 5.0 : f64
%added = arith.addf %element, %bias : f64
linalg.yield %added : f64
} -> tensor<1x320xf64>
func.return %out2 : tensor<1x320xf64>
}

// Same computation; l1_tiles = [0, 64, 80].
func.func @test64(%arg0: tensor<1x400xf64>, %arg1: tensor<320x400xf64>) -> tensor<1x320xf64> {
%init = tensor.empty() : tensor<1x320xf64>
%out = linalg.matmul_transpose_b {
lowering_config = #quidditch_snitch.lowering_config<
l1_tiles = [0, 64, 80],
l1_tiles_interchange = [2, 0, 1],
dual_buffer = true
>
}
ins(%arg0, %arg1 : tensor<1x400xf64>, tensor<320x400xf64>)
outs(%init : tensor<1x320xf64>) -> tensor<1x320xf64>
// Elementwise +5.0 bias add (same ins/outs aliasing note as @test32).
%out2 = linalg.generic {
indexing_maps = [
affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>
], iterator_types = ["parallel", "parallel"]
} ins(%out : tensor<1x320xf64>) outs(%out : tensor<1x320xf64>) {
^bb0(%element : f64, %outs : f64):
%bias = arith.constant 5.0 : f64
%added = arith.addf %element, %bias : f64
linalg.yield %added : f64
} -> tensor<1x320xf64>
func.return %out2 : tensor<1x320xf64>
}

// Same computation; l1_tiles = [0, 32, 100]. Note 100 does not divide 400
// evenly into the 80-sized variants' schedule -- this variant exercises a
// different K tile size, presumably for performance comparison.
func.func @test32_100(%arg0: tensor<1x400xf64>, %arg1: tensor<320x400xf64>) -> tensor<1x320xf64> {
%init = tensor.empty() : tensor<1x320xf64>
%out = linalg.matmul_transpose_b {
lowering_config = #quidditch_snitch.lowering_config<
l1_tiles = [0, 32, 100],
l1_tiles_interchange = [2, 0, 1],
dual_buffer = true
>
}
ins(%arg0, %arg1 : tensor<1x400xf64>, tensor<320x400xf64>)
outs(%init : tensor<1x320xf64>) -> tensor<1x320xf64>
// Elementwise +5.0 bias add (same ins/outs aliasing note as @test32).
%out2 = linalg.generic {
indexing_maps = [
affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>
], iterator_types = ["parallel", "parallel"]
} ins(%out : tensor<1x320xf64>) outs(%out : tensor<1x320xf64>) {
^bb0(%element : f64, %outs : f64):
%bias = arith.constant 5.0 : f64
%added = arith.addf %element, %bias : f64
linalg.yield %added : f64
} -> tensor<1x320xf64>
func.return %out2 : tensor<1x320xf64>
}

// Same computation; l1_tiles = [0, 40, 100].
func.func @test40_100(%arg0: tensor<1x400xf64>, %arg1: tensor<320x400xf64>) -> tensor<1x320xf64> {
%init = tensor.empty() : tensor<1x320xf64>
%out = linalg.matmul_transpose_b {
lowering_config = #quidditch_snitch.lowering_config<
l1_tiles = [0, 40, 100],
l1_tiles_interchange = [2, 0, 1],
dual_buffer = true
>
}
ins(%arg0, %arg1 : tensor<1x400xf64>, tensor<320x400xf64>)
outs(%init : tensor<1x320xf64>) -> tensor<1x320xf64>
// Elementwise +5.0 bias add (same ins/outs aliasing note as @test32).
%out2 = linalg.generic {
indexing_maps = [
affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>
], iterator_types = ["parallel", "parallel"]
} ins(%out : tensor<1x320xf64>) outs(%out : tensor<1x320xf64>) {
^bb0(%element : f64, %outs : f64):
%bias = arith.constant 5.0 : f64
%added = arith.addf %element, %bias : f64
linalg.yield %added : f64
} -> tensor<1x320xf64>
func.return %out2 : tensor<1x320xf64>
}
}
Loading

0 comments on commit 7466aa6

Please sign in to comment.