Skip to content

Commit

Permalink
[Gemmini Dialect] Gemmini Dialect enhancement on tiled_matmul (#178)
Browse files Browse the repository at this point in the history
Add an activation function to Gemmini Dialect.
  • Loading branch information
Xinyu302 authored Oct 30, 2023
1 parent f356a61 commit ba241c1
Show file tree
Hide file tree
Showing 19 changed files with 1,452 additions and 122 deletions.
11 changes: 10 additions & 1 deletion backend/include/llvm/IR/IntrinsicsRISCVBuddyExt.td
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,13 @@
//
//===----------------------------------------------------------------------===//
let TargetPrefix = "riscv" in
def int_riscv_mvin : Intrinsic<[],[llvm_i64_ty, llvm_i64_ty],[]>;
def int_riscv_mvin : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty],[]>;

// Intrinsic `int_riscv_mvin2` (target prefix "riscv"): no results, two i64
// operands, empty property list. RISCVInstrInfoBuddyExt.td selects it to the
// MVIN2 instruction.
let TargetPrefix = "riscv" in
def int_riscv_mvin2 : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty],[]>;

// Intrinsic `int_riscv_mvin3` (target prefix "riscv"): no results, two i64
// operands, empty property list. RISCVInstrInfoBuddyExt.td selects it to the
// MVIN3 instruction.
let TargetPrefix = "riscv" in
def int_riscv_mvin3 : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty],[]>;

// Intrinsic `int_riscv_mvout` (target prefix "riscv"): no results, two i64
// operands, empty property list. RISCVInstrInfoBuddyExt.td selects it to the
// MVOUT instruction.
let TargetPrefix = "riscv" in
def int_riscv_mvout : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>;
Expand All @@ -35,6 +41,9 @@ def int_riscv_config_st : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>;
// Intrinsic `int_riscv_config_ex` (target prefix "riscv"): no results, two i64
// operands, empty property list. RISCVInstrInfoBuddyExt.td selects it to the
// CONFIG_EX instruction.
let TargetPrefix = "riscv" in
def int_riscv_config_ex : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>;

// Intrinsic `int_riscv_config_norm` (target prefix "riscv"): no results, two
// i64 operands, empty property list. RISCVInstrInfoBuddyExt.td selects it to
// the CONFIG_NORM instruction.
let TargetPrefix = "riscv" in
def int_riscv_config_norm : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>;

// Intrinsic `int_riscv_preload` (target prefix "riscv"): no results, two i64
// operands, empty property list. RISCVInstrInfoBuddyExt.td selects it to the
// PRELOAD instruction.
let TargetPrefix = "riscv" in
def int_riscv_preload : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>;

Expand Down
27 changes: 27 additions & 0 deletions backend/llvm/lib/Target/RISCV/RISCVInstrInfoBuddyExt.td
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,18 @@ def MVIN : RVInstR<0b0000010, 0b011, OPC_CUSTOM_3, (outs),
let rd = 0;
}

// MVIN2: R-type instruction under OPC_CUSTOM_3 with funct7 = 0b0000001 and
// funct3 = 0b011. It takes two GPR sources ($rs1, $rs2), defines no outputs,
// and pins the rd field to 0. It is marked as having side effects and as
// possibly loading and storing, and requires the HasBuddyExt predicate.
// NOTE(review): presumably a second Gemmini move-in variant, judging by the
// name and the sibling MVIN record — confirm against the Gemmini ISA.
let hasSideEffects = 1, mayLoad = 1, mayStore = 1, Predicates = [HasBuddyExt] in
def MVIN2 : RVInstR<0b0000001, 0b011, OPC_CUSTOM_3, (outs),
(ins GPR:$rs1, GPR:$rs2), "mvin2","$rs1, $rs2"> {
let rd = 0;
}

// MVIN3: R-type instruction under OPC_CUSTOM_3 with funct7 = 0b0001110 and
// funct3 = 0b011. It takes two GPR sources ($rs1, $rs2), defines no outputs,
// and pins the rd field to 0. It is marked as having side effects and as
// possibly loading and storing, and requires the HasBuddyExt predicate.
// NOTE(review): presumably a third Gemmini move-in variant, judging by the
// name and the sibling MVIN record — confirm against the Gemmini ISA.
let hasSideEffects = 1, mayLoad = 1, mayStore = 1, Predicates = [HasBuddyExt] in
def MVIN3 : RVInstR<0b0001110, 0b011, OPC_CUSTOM_3, (outs),
(ins GPR:$rs1, GPR:$rs2), "mvin3","$rs1, $rs2"> {
let rd = 0;
}

let hasSideEffects = 1, mayLoad = 1, mayStore = 1, Predicates = [HasBuddyExt] in
def MVOUT : RVInstR<0b0000011, 0b011, OPC_CUSTOM_3, (outs),
(ins GPR:$rs1, GPR:$rs2), "mvout","$rs1, $rs2">{
Expand Down Expand Up @@ -65,6 +77,12 @@ def CONFIG_EX : RVInstR<0b0000000, 0b011, OPC_CUSTOM_3,(outs),
let rd = 0;
}

// CONFIG_NORM: R-type instruction under OPC_CUSTOM_3 with funct7 = 0b0000000
// and funct3 = 0b011. It takes two GPR sources ($rs1, $rs2), defines no
// outputs, and pins the rd field to 0. Requires the HasBuddyExt predicate.
// NOTE(review): funct7/funct3/opcode are identical to CONFIG_EX in this file,
// so the two records share one binary encoding. Presumably the configuration
// kind is carried in the operand values rather than in the encoding — confirm
// against the Gemmini ISA, and check that this does not confuse the
// disassembler.
// NOTE(review): unlike MVIN2/MVIN3, no hasSideEffects/mayLoad/mayStore flags
// are set here — confirm this is intended for a configuration instruction.
let Predicates = [HasBuddyExt] in
def CONFIG_NORM : RVInstR<0b0000000, 0b011, OPC_CUSTOM_3,(outs),
(ins GPR:$rs1, GPR:$rs2), "config_norm", "$rs1, $rs2"> {
let rd = 0;
}

let hasSideEffects = 1, mayLoad = 1, mayStore =1, Predicates = [HasBuddyExt] in
def PRELOAD : RVInstR<0b0000110, 0b011,OPC_CUSTOM_3,(outs),
(ins GPR:$rs1, GPR:$rs2), "preload", "$rs1, $rs2">{
Expand Down Expand Up @@ -164,6 +182,12 @@ def LOOP_CONV_WS_CONFIG6 : RVInstR<0b0010101, 0b011, OPC_CUSTOM_3, (outs),
// Selection patterns for the move-in/move-out intrinsics. Each pattern matches
// an intrinsic call with two GPR operands and selects the same-named
// instruction, forwarding $rs1 and $rs2 unchanged. All of them are guarded by
// the HasBuddyExt predicate.
let Predicates = [HasBuddyExt] in
def : Pat<(int_riscv_mvin GPR:$rs1, GPR:$rs2), (MVIN GPR:$rs1, GPR:$rs2)>;

let Predicates = [HasBuddyExt] in
def : Pat<(int_riscv_mvin2 GPR:$rs1, GPR:$rs2), (MVIN2 GPR:$rs1, GPR:$rs2)>;

let Predicates = [HasBuddyExt] in
def : Pat<(int_riscv_mvin3 GPR:$rs1, GPR:$rs2), (MVIN3 GPR:$rs1, GPR:$rs2)>;

let Predicates = [HasBuddyExt] in
def : Pat<(int_riscv_mvout GPR:$rs1, GPR:$rs2), (MVOUT GPR:$rs1, GPR:$rs2)>;

Expand All @@ -179,6 +203,9 @@ def : Pat<(int_riscv_config_st GPR:$rs1, GPR:$rs2), (CONFIG_ST GPR:$rs1, GPR:$rs
// Selection patterns for config_ex, config_norm and preload. Each pattern
// matches an intrinsic call with two GPR operands and selects the same-named
// instruction, forwarding $rs1 and $rs2 unchanged. All of them are guarded by
// the HasBuddyExt predicate.
let Predicates = [HasBuddyExt] in
def : Pat<(int_riscv_config_ex GPR:$rs1, GPR:$rs2), (CONFIG_EX GPR:$rs1, GPR:$rs2)>;

let Predicates = [HasBuddyExt] in
def : Pat<(int_riscv_config_norm GPR:$rs1, GPR:$rs2), (CONFIG_NORM GPR:$rs1, GPR:$rs2)>;

let Predicates = [HasBuddyExt] in
def : Pat<(int_riscv_preload GPR:$rs1, GPR:$rs2), (PRELOAD GPR:$rs1, GPR:$rs2)>;

Expand Down
24 changes: 12 additions & 12 deletions examples/GemminiDialect/ciface.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -127,53 +127,53 @@ func.func @linalg_conv6(%arg0 : memref<1x1x256x256xi8>, %arg1 : memref<1x1x13x13
func.func @gemmini_conv1(%input: memref<1x256x256x1xi8>, %weights: memref<9x1xi8>, %bias: memref<1xi32>, %output: memref<64516x1xi8>) {
%outdim = arith.constant 254 : i64
%kernelDim = arith.constant 3 : i64
gemmini.tile_conv %input %weights %bias %output %outdim %kernelDim {stride = 1} :
memref<1x256x256x1xi8> memref<9x1xi8> memref<1xi32> memref<64516x1xi8> i64 i64
gemmini.tile_conv %input %weights %bias %output %outdim %outdim %kernelDim {stride = 1} :
memref<1x256x256x1xi8> memref<9x1xi8> memref<1xi32> memref<64516x1xi8> i64 i64 i64
return
}

// CHECK: llvm.func @_mlir_ciface_gemmini_conv2
func.func @gemmini_conv2(%input: memref<1x256x256x1xi8>, %weights: memref<25x1xi8>, %bias: memref<1xi32>, %output: memref<63504x1xi8>) {
%outdim = arith.constant 252 : i64
%kernelDim = arith.constant 5 : i64
gemmini.tile_conv %input %weights %bias %output %outdim %kernelDim {stride = 1} :
memref<1x256x256x1xi8> memref<25x1xi8> memref<1xi32> memref<63504x1xi8> i64 i64
gemmini.tile_conv %input %weights %bias %output %outdim %outdim %kernelDim {stride = 1} :
memref<1x256x256x1xi8> memref<25x1xi8> memref<1xi32> memref<63504x1xi8> i64 i64 i64
return
}

// CHECK: llvm.func @_mlir_ciface_gemmini_conv3
func.func @gemmini_conv3(%input: memref<1x256x256x1xi8>, %weights: memref<49x1xi8>, %bias: memref<1xi32>, %output: memref<62500x1xi8>) {
%outdim = arith.constant 250 : i64
%kernelDim = arith.constant 7 : i64
gemmini.tile_conv %input %weights %bias %output %outdim %kernelDim {stride = 1} :
memref<1x256x256x1xi8> memref<49x1xi8> memref<1xi32> memref<62500x1xi8> i64 i64
gemmini.tile_conv %input %weights %bias %output %outdim %outdim %kernelDim {stride = 1} :
memref<1x256x256x1xi8> memref<49x1xi8> memref<1xi32> memref<62500x1xi8> i64 i64 i64
return
}

// CHECK: llvm.func @_mlir_ciface_gemmini_conv4
func.func @gemmini_conv4(%input: memref<1x256x256x1xi8>, %weights: memref<81x1xi8>, %bias: memref<1xi32>, %output: memref<61504x1xi8>) {
%outdim = arith.constant 248 : i64
%kernelDim = arith.constant 9 : i64
gemmini.tile_conv %input %weights %bias %output %outdim %kernelDim {stride = 1} :
memref<1x256x256x1xi8> memref<81x1xi8> memref<1xi32> memref<61504x1xi8> i64 i64
gemmini.tile_conv %input %weights %bias %output %outdim %outdim %kernelDim {stride = 1} :
memref<1x256x256x1xi8> memref<81x1xi8> memref<1xi32> memref<61504x1xi8> i64 i64 i64
return
}

// CHECK: llvm.func @_mlir_ciface_gemmini_conv5
func.func @gemmini_conv5(%input: memref<1x256x256x1xi8>, %weights: memref<121x1xi8>, %bias: memref<1xi32>, %output: memref<60516x1xi8>) {
%outdim = arith.constant 246 : i64
%kernelDim = arith.constant 11 : i64
gemmini.tile_conv %input %weights %bias %output %outdim %kernelDim {stride = 1} :
memref<1x256x256x1xi8> memref<121x1xi8> memref<1xi32> memref<60516x1xi8> i64 i64
gemmini.tile_conv %input %weights %bias %output %outdim %outdim %kernelDim {stride = 1} :
memref<1x256x256x1xi8> memref<121x1xi8> memref<1xi32> memref<60516x1xi8> i64 i64 i64
return
}

// CHECK: llvm.func @_mlir_ciface_gemmini_conv6
func.func @gemmini_conv6(%input: memref<1x256x256x1xi8>, %weights: memref<169x1xi8>, %bias: memref<1xi32>, %output: memref<59536x1xi8>) {
%outdim = arith.constant 244 : i64
%kernelDim = arith.constant 13 : i64
gemmini.tile_conv %input %weights %bias %output %outdim %kernelDim {stride = 1} :
memref<1x256x256x1xi8> memref<169x1xi8> memref<1xi32> memref<59536x1xi8> i64 i64
gemmini.tile_conv %input %weights %bias %output %outdim %outdim %kernelDim {stride = 1} :
memref<1x256x256x1xi8> memref<169x1xi8> memref<1xi32> memref<59536x1xi8> i64 i64 i64
return
}

Expand Down
90 changes: 90 additions & 0 deletions examples/GemminiDialect/makefile
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,51 @@ tile-matmul-run:
@riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out
@spike --extension=gemmini pk a.out

# Lower tile-matmul-os.mlir with -lower-gemmini, translate it to LLVM IR,
# compile it to a riscv64 object (log.o) with the buddyext and D features,
# link it statically into a.out, and run a.out under spike with the gemmini
# extension.
tile-matmul-os-run:
@${BUDDY_OPT} ./tile-matmul-os.mlir -lower-gemmini | \
${BUDDY_TRANSLATE} --buddy-to-llvmir | \
${BUDDY_LLC} -filetype=obj -mtriple=riscv64 \
-mattr=+buddyext,+D -float-abi=hard \
-o log.o
@riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out
@spike --extension=gemmini pk a.out

# Lower tile-matmul-ws-igelu.mlir with -lower-gemmini, translate it to LLVM IR,
# compile it to a riscv64 object (log.o) with the buddyext and D features,
# link it statically into a.out, and run a.out under spike with the gemmini
# extension.
tile-matmul-ws-igelu-run:
@${BUDDY_OPT} ./tile-matmul-ws-igelu.mlir -lower-gemmini | \
${BUDDY_TRANSLATE} --buddy-to-llvmir | \
${BUDDY_LLC} -filetype=obj -mtriple=riscv64 \
-mattr=+buddyext,+D -float-abi=hard \
-o log.o
@riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out
@spike --extension=gemmini pk a.out

# Lower tile-matmul-ws-relu.mlir with -lower-gemmini, translate it to LLVM IR,
# compile it to a riscv64 object (log.o) with the buddyext and D features,
# link it statically into a.out, and run a.out under spike with the gemmini
# extension.
tile-matmul-ws-relu-run:
@${BUDDY_OPT} ./tile-matmul-ws-relu.mlir -lower-gemmini | \
${BUDDY_TRANSLATE} --buddy-to-llvmir | \
${BUDDY_LLC} -filetype=obj -mtriple=riscv64 \
-mattr=+buddyext,+D -float-abi=hard \
-o log.o
@riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out
@spike --extension=gemmini pk a.out

# Lower tile-matmul-ws-softmax.mlir with -lower-gemmini, translate it to LLVM
# IR, compile it to a riscv64 object (log.o) with the buddyext and D features,
# link it statically into a.out, and run a.out under spike with the gemmini
# extension.
tile-matmul-ws-softmax-run:
@${BUDDY_OPT} ./tile-matmul-ws-softmax.mlir -lower-gemmini | \
${BUDDY_TRANSLATE} --buddy-to-llvmir | \
${BUDDY_LLC} -filetype=obj -mtriple=riscv64 \
-mattr=+buddyext,+D -float-abi=hard \
-o log.o
@riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out
@spike --extension=gemmini pk a.out

# Lower tile-matmul-ws-layernorm.mlir with -lower-gemmini, translate it to LLVM
# IR, compile it to a riscv64 object (log.o) with the buddyext and D features,
# link it statically into a.out, and run a.out under spike with the gemmini
# extension.
tile-matmul-ws-layernorm-run:
@${BUDDY_OPT} ./tile-matmul-ws-layernorm.mlir -lower-gemmini | \
${BUDDY_TRANSLATE} --buddy-to-llvmir | \
${BUDDY_LLC} -filetype=obj -mtriple=riscv64 \
-mattr=+buddyext,+D -float-abi=hard \
-o log.o
@riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out
@spike --extension=gemmini pk a.out

tile-conv-run:
@${BUDDY_OPT} ./tile-conv.mlir -lower-gemmini | \
${BUDDY_TRANSLATE} --buddy-to-llvmir | \
Expand All @@ -85,6 +130,51 @@ tile-conv-run:
@riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out
@spike --extension=gemmini pk a.out

# Lower tile-conv-igelu.mlir with -lower-gemmini, translate it to LLVM IR,
# compile it to a riscv64 object (log.o) with the buddyext and D features,
# link it statically into a.out, and run a.out under spike with the gemmini
# extension.
tile-conv-igelu-run:
@${BUDDY_OPT} ./tile-conv-igelu.mlir -lower-gemmini | \
${BUDDY_TRANSLATE} --buddy-to-llvmir | \
${BUDDY_LLC} -filetype=obj -mtriple=riscv64 \
-mattr=+buddyext,+D -float-abi=hard \
-o log.o
@riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out
@spike --extension=gemmini pk a.out

# Lower tile-conv-softmax.mlir with -lower-gemmini, translate it to LLVM IR,
# compile it to a riscv64 object (log.o) with the buddyext and D features,
# link it statically into a.out, and run a.out under spike with the gemmini
# extension.
tile-conv-softmax-run:
@${BUDDY_OPT} ./tile-conv-softmax.mlir -lower-gemmini | \
${BUDDY_TRANSLATE} --buddy-to-llvmir | \
${BUDDY_LLC} -filetype=obj -mtriple=riscv64 \
-mattr=+buddyext,+D -float-abi=hard \
-o log.o
@riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out
@spike --extension=gemmini pk a.out

# Lower tile-conv-relu.mlir with -lower-gemmini, translate it to LLVM IR,
# compile it to a riscv64 object (log.o) with the buddyext and D features,
# link it statically into a.out, and run a.out under spike with the gemmini
# extension.
tile-conv-relu-run:
@${BUDDY_OPT} ./tile-conv-relu.mlir -lower-gemmini | \
${BUDDY_TRANSLATE} --buddy-to-llvmir | \
${BUDDY_LLC} -filetype=obj -mtriple=riscv64 \
-mattr=+buddyext,+D -float-abi=hard \
-o log.o
@riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out
@spike --extension=gemmini pk a.out

# Lower tile-conv-layernorm.mlir with -lower-gemmini, translate it to LLVM IR,
# compile it to a riscv64 object (log.o) with the buddyext and D features,
# link it statically into a.out, and run a.out under spike with the gemmini
# extension.
tile-conv-layernorm-run:
@${BUDDY_OPT} ./tile-conv-layernorm.mlir -lower-gemmini | \
${BUDDY_TRANSLATE} --buddy-to-llvmir | \
${BUDDY_LLC} -filetype=obj -mtriple=riscv64 \
-mattr=+buddyext,+D -float-abi=hard \
-o log.o
@riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out
@spike --extension=gemmini pk a.out

# Lower tile-rect-conv.mlir with -lower-gemmini, translate it to LLVM IR,
# compile it to a riscv64 object (log.o) with the buddyext and D features,
# link it statically into a.out, and run a.out under spike with the gemmini
# extension.
tile-rect-conv-run:
@${BUDDY_OPT} ./tile-rect-conv.mlir -lower-gemmini | \
${BUDDY_TRANSLATE} --buddy-to-llvmir | \
${BUDDY_LLC} -filetype=obj -mtriple=riscv64 \
-mattr=+buddyext,+D -float-abi=hard \
-o log.o
@riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out
@spike --extension=gemmini pk a.out

gemmini-linalg-matmul-run:
@${BUDDY_OPT} ./matmul.mlir \
-convert-linalg-to-gemmini \
Expand Down
52 changes: 52 additions & 0 deletions examples/GemminiDialect/tile-conv-igelu.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// RUN: buddy-opt %s \
// RUN: --lower-gemmini | \
// RUN: FileCheck %s

// batchSize = 1 inputDim = 5 inChannels = 1
memref.global "private" @input : memref<1x5x5x1xi8> = dense<[[[[1], [0], [-1], [0], [1]],
[[1], [0], [-1], [0], [1]],
[[1], [0], [-1], [0], [1]],
[[1], [0], [-1], [0], [1]],
[[1], [0], [-1], [0], [1]]]]>

// outChannels = 2 kernelDim = 3 inChannels = 1
memref.global "private" @weight : memref<9x2xi8> = dense<[[-1, 2], [-1, 2], [-1, 2],
[-1, 2], [-1, 2], [-1, 2],
[-1, 2], [-1, 2], [-1, 2]]>

// outChannels = 2
memref.global "private" @bias : memref<2xi32> = dense<[1,1]>

// Entry point. Issues gemmini.tile_conv twice on the global input, weight and
// bias buffers -- first without an `act` attribute, then with `act = 3` --
// printing the output buffer after each call. Always returns 0.
func.func @main() -> i64 {
%0 = arith.constant 0 : i64
// %3 is passed for all three i64 operands of gemmini.tile_conv below.
// NOTE(review): presumably these are the output row dimension, the output
// column dimension and the kernel dimension, in that order -- confirm against
// the op definition.
%3 = arith.constant 3 : i64
%input = memref.get_global @input : memref<1x5x5x1xi8>
%weight = memref.get_global @weight : memref<9x2xi8>
%bias = memref.get_global @bias : memref<2xi32>
// The same output buffer is passed to both tile_conv calls. It is not
// deallocated inside this function.
%output = memref.alloc() : memref<9x2xi8>

// First call: stride = 1, no `act` attribute.
// CHECK: "gemmini.intr.loop_conv_ws_config1"
// CHECK: "gemmini.intr.loop_conv_ws_config2"
// CHECK: "gemmini.intr.loop_conv_ws_config3"
// CHECK: "gemmini.intr.loop_conv_ws_config4"
// CHECK: "gemmini.intr.loop_conv_ws_config5"
// CHECK: "gemmini.intr.loop_conv_ws_config6"
// CHECK: "gemmini.intr.loop_conv_ws"
// CHECK: "gemmini.intr.flush"
gemmini.tile_conv %input %weight %bias %output %3 %3 %3 {stride = 1}:
memref<1x5x5x1xi8> memref<9x2xi8> memref<2xi32> memref<9x2xi8> i64 i64 i64
gemmini.print %output : memref<9x2xi8>

// Second call: same operands with `act = 3`.
// NOTE(review): presumably 3 selects the I-GELU activation (this file is
// tile-conv-igelu.mlir) -- confirm against the dialect's activation encoding.
// The FileCheck directives below are identical to those of the first call, so
// this test does not verify that the `act` value reaches the lowered ops.
// CHECK: "gemmini.intr.loop_conv_ws_config1"
// CHECK: "gemmini.intr.loop_conv_ws_config2"
// CHECK: "gemmini.intr.loop_conv_ws_config3"
// CHECK: "gemmini.intr.loop_conv_ws_config4"
// CHECK: "gemmini.intr.loop_conv_ws_config5"
// CHECK: "gemmini.intr.loop_conv_ws_config6"
// CHECK: "gemmini.intr.loop_conv_ws"
// CHECK: "gemmini.intr.flush"
gemmini.tile_conv %input %weight %bias %output %3 %3 %3 {stride = 1, act = 3}:
memref<1x5x5x1xi8> memref<9x2xi8> memref<2xi32> memref<9x2xi8> i64 i64 i64
gemmini.print %output : memref<9x2xi8>
return %0 : i64
}
52 changes: 52 additions & 0 deletions examples/GemminiDialect/tile-conv-layernorm.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// RUN: buddy-opt %s \
// RUN: --lower-gemmini | \
// RUN: FileCheck %s

// batchSize = 1 inputDim = 5 inChannels = 1
memref.global "private" @input : memref<1x5x5x1xi8> = dense<[[[[1], [0], [-1], [0], [1]],
[[1], [0], [-1], [0], [1]],
[[1], [0], [-1], [0], [1]],
[[1], [0], [-1], [0], [1]],
[[1], [0], [-1], [0], [1]]]]>

// outChannels = 2 kernelDim = 3 inChannels = 1
memref.global "private" @weight : memref<9x2xi8> = dense<[[-1, 2], [-1, 2], [-1, 2],
[-1, 2], [-1, 2], [-1, 2],
[-1, 2], [-1, 2], [-1, 2]]>

// outChannels = 2
memref.global "private" @bias : memref<2xi32> = dense<[1,1]>

// Entry point. Issues gemmini.tile_conv twice on the global input, weight and
// bias buffers -- first without an `act` attribute, then with `act = 2` --
// printing the output buffer after each call. Always returns 0.
func.func @main() -> i64 {
%0 = arith.constant 0 : i64
// %3 is passed for all three i64 operands of gemmini.tile_conv below.
// NOTE(review): presumably these are the output row dimension, the output
// column dimension and the kernel dimension, in that order -- confirm against
// the op definition.
%3 = arith.constant 3 : i64
%input = memref.get_global @input : memref<1x5x5x1xi8>
%weight = memref.get_global @weight : memref<9x2xi8>
%bias = memref.get_global @bias : memref<2xi32>
// The same output buffer is passed to both tile_conv calls. It is not
// deallocated inside this function.
%output = memref.alloc() : memref<9x2xi8>

// First call: stride = 1, no `act` attribute.
// CHECK: "gemmini.intr.loop_conv_ws_config1"
// CHECK: "gemmini.intr.loop_conv_ws_config2"
// CHECK: "gemmini.intr.loop_conv_ws_config3"
// CHECK: "gemmini.intr.loop_conv_ws_config4"
// CHECK: "gemmini.intr.loop_conv_ws_config5"
// CHECK: "gemmini.intr.loop_conv_ws_config6"
// CHECK: "gemmini.intr.loop_conv_ws"
// CHECK: "gemmini.intr.flush"
gemmini.tile_conv %input %weight %bias %output %3 %3 %3 {stride = 1}:
memref<1x5x5x1xi8> memref<9x2xi8> memref<2xi32> memref<9x2xi8> i64 i64 i64
gemmini.print %output : memref<9x2xi8>

// Second call: same operands with `act = 2`.
// NOTE(review): presumably 2 selects the layernorm activation (this file is
// tile-conv-layernorm.mlir) -- confirm against the dialect's activation
// encoding. The FileCheck directives below are identical to those of the first
// call, so this test does not verify that the `act` value reaches the lowered
// ops.
// CHECK: "gemmini.intr.loop_conv_ws_config1"
// CHECK: "gemmini.intr.loop_conv_ws_config2"
// CHECK: "gemmini.intr.loop_conv_ws_config3"
// CHECK: "gemmini.intr.loop_conv_ws_config4"
// CHECK: "gemmini.intr.loop_conv_ws_config5"
// CHECK: "gemmini.intr.loop_conv_ws_config6"
// CHECK: "gemmini.intr.loop_conv_ws"
// CHECK: "gemmini.intr.flush"
gemmini.tile_conv %input %weight %bias %output %3 %3 %3 {stride = 1, act = 2}:
memref<1x5x5x1xi8> memref<9x2xi8> memref<2xi32> memref<9x2xi8> i64 i64 i64
gemmini.print %output : memref<9x2xi8>
return %0 : i64
}
Loading

0 comments on commit ba241c1

Please sign in to comment.