From 521277cd1a8044c5930608fd2f056d8ce5181b4f Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Tue, 21 Nov 2023 15:31:38 +0300 Subject: [PATCH] [mlir][spirv] Add `CL.mix` op (#72800) --- .../mlir/Dialect/SPIRV/IR/SPIRVCLOps.td | 31 +++++++++++++++++++ mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir | 20 ++++++++++++ 2 files changed, 51 insertions(+) diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td index c4900fb79f346f..026f59b2afd8e2 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td @@ -111,6 +111,37 @@ class SPIRV_CLTernaryArithmeticOp { + let summary = "Returns the linear blend of x & y implemented as: x + (y - x) * a"; + + let description = [{ + Result Type, x, y and a must be floating-point or vector(2,3,4,8,16) of + floating-point values. + + All of the operands, including the Result Type operand, must be of the + same type. + + Note: This instruction can be implemented using contractions such as mad + or fma. + + + + ``` + mix-op ::= ssa-id `=` `spirv.CL.mix` ssa-use, ssa-use, ssa-use `:` + float-scalar-vector-type + ``` + + #### Example: + + ```mlir + %0 = spirv.CL.mix %a, %b, %c : f32 + %1 = spirv.CL.mix %a, %b, %c : vector<3xf16> + ``` + }]; +} + // ----- def SPIRV_CLCeilOp : SPIRV_CLUnaryArithmeticOp<"ceil", 12, SPIRV_Float> { diff --git a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir index 29a4a46136156a..f4e3a83e39f247 100644 --- a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir @@ -188,6 +188,26 @@ func.func @fma(%a : vector<3xf32>, %b : vector<3xf32>, %c : vector<3xf32>) -> () // ----- +//===----------------------------------------------------------------------===// +// spirv.CL.mix +//===----------------------------------------------------------------------===// + +func.func @mix(%a : f32, %b : f32, %c : f32) -> () { + // CHECK: spirv.CL.mix {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : f32 + %2 = spirv.CL.mix %a, %b, %c : f32 + return +} + +// ----- + +func.func @mix(%a : vector<3xf32>, %b : vector<3xf32>, %c : vector<3xf32>) -> () { + // CHECK: spirv.CL.mix {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : vector<3xf32> + %2 = spirv.CL.mix %a, %b, %c : vector<3xf32> + return +} + +// ----- + //===----------------------------------------------------------------------===// // spirv.CL.{F|S|U}{Max|Min} //===----------------------------------------------------------------------===//