Skip to content

Commit

Permalink
[xllvm] Add support for multi-result intrinsics (#1470)
Browse files Browse the repository at this point in the history
  • Loading branch information
jsetoain authored May 21, 2024
1 parent 66e642b commit 92c6ed9
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 19 deletions.
13 changes: 13 additions & 0 deletions include/aie/Dialect/XLLVM/IR/XLLVMAIE2IntrOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#define AIE_DIALECT_XLLVM_IR_XLLVMAIE2INTROPS_TD

include "aie/Dialect/XLLVM/IR/XLLVM.td"
include "aie/Dialect/XLLVM/IR/XLLVMTypeConstraints.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

Expand Down Expand Up @@ -338,4 +339,16 @@ def VectorExtractElem32I512IntrOp :
I32:$idx,
I32:$sign)>;

// ----- MAX ELEMENT -----

def VectorMaxLtBf16IntrOp :
AIEVec2_IntrOp<"vmax.ltbf16",
[TypeIs<"res",
LLVM_StructOf<[
VectorOfLengthAndType<[32], [BF16]>,
I32]>
>], /*numResults=*/2>,
Arguments<(ins VectorOfLengthAndType<[32], [BF16]>:$lhs,
VectorOfLengthAndType<[32], [BF16]>:$rhs)>;

#endif // AIE_DIALECT_XLLVM_IR_XLLVMAIE2INTROPS_TD
46 changes: 46 additions & 0 deletions include/aie/Dialect/XLLVM/IR/XLLVMTypeConstraints.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
//===- XLLVMTypeConstraints.td - XLLVM type constraints. --*- tablegen -*-====//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// (c) Copyright 2024 Advanced Micro Devices, Inc.
//
//===----------------------------------------------------------------------===//
// Defines type constraints for LLVM types used in XLLVM intrinsic op
// definitions.
//===----------------------------------------------------------------------===//


#ifndef AIE_DIALECT_XLLVM_IR_XLLVMTYPECONSTRAINTS_TD
#define AIE_DIALECT_XLLVM_IR_XLLVMTYPECONSTRAINTS_TD

class EnumeratedType<Type ty, int idx> {
Type type = ty;
int index = idx;
}

class EnumerateTypeListFrom<list<Type> tlist, int from = 0> {
list<EnumeratedType> sequence =
!if(!empty(tlist), [],
!listconcat(
[EnumeratedType<!head(tlist), from>],
EnumerateTypeListFrom<!tail(tlist), !add(from, 1)>.sequence
));
}

class LLVM_StructOf<list<Type> structTypes> :
Type<
And<[LLVM_AnyStruct.predicate,
CPred<"cast<::mlir::LLVM::LLVMStructType>($_self).getBody().size() == " # !size(structTypes)>,
And<!foreach(enumTy, EnumerateTypeListFrom<structTypes>.sequence,
SubstLeaves<"$_self",
"cast<::mlir::LLVM::LLVMStructType>($_self).getBody()[" # enumTy.index # "]",
enumTy.type.predicate>
)
>
]>,
"an LLVM struct of {" # !interleave(!foreach(ty, structTypes, ty.summary), "; ") # "}"
>;

#endif // AIE_DIALECT_XLLVM_IR_XLLVMTYPECONSTRAINTS_TD
12 changes: 8 additions & 4 deletions lib/Dialect/XLLVM/XLLVMOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,16 @@ getNamedIntrinsicDeclaration(llvm::Module *M, llvm::StringRef fullName,
llvm::CallInst *createExternalLLVMIntrinsicCall(
llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation,
Operation *intrOp, llvm::StringRef intrinsicName) {
// We support 0 or 1 results
assert(intrOp->getNumResults() <= 1 &&
"external multi-result intrinsics not supported");
llvm::Type *resTy = nullptr;
if (intrOp->getNumResults())
unsigned numResults = intrOp->getNumResults();
if (numResults == 1)
resTy = moduleTranslation.convertType(*(intrOp->getResultTypes().begin()));
else if (numResults > 1) {
SmallVector<llvm::Type *> resTys;
for (auto ty : intrOp->getResultTypes())
resTys.push_back(moduleTranslation.convertType(ty));
resTy = llvm::StructType::get(builder.getContext(), resTys);
}
auto operands = moduleTranslation.lookupValues(intrOp->getOperands());
SmallVector<llvm::Type *> types;
for (auto op : operands)
Expand Down
44 changes: 29 additions & 15 deletions test/Target/LLVMIR/aievec.mlir
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// RUN: aie-translate %s -mlir-to-llvmir | FileCheck %s
// RUN: aie-translate %s -mlir-to-llvmir -split-input-file | FileCheck %s

// ----- MAC -----
// -- MAC --

// CHECK-LABEL: define <16 x i64> @mac_conf_acc32
llvm.func @mac_conf_acc32(%A : vector<64xi8>,
Expand Down Expand Up @@ -30,7 +30,7 @@ llvm.func @mac_conf_bf16(%A : vector<32xbf16>,
llvm.return %0 : vector<8xi64>
}

// ----- MSC -----
// -- MSC --

// CHECK-LABEL: define <8 x i64> @msc_conf_bf16
llvm.func @msc_conf_bf16(%A : vector<32xbf16>,
Expand All @@ -46,7 +46,7 @@ llvm.func @msc_conf_bf16(%A : vector<32xbf16>,
llvm.return %0 : vector<8xi64>
}

// ----- MUL -----
// -- MUL --

// CHECK-LABEL: define <16 x i64> @mul_conf_acc32
llvm.func @mul_conf_acc32(%A : vector<64xi8>,
Expand Down Expand Up @@ -87,7 +87,7 @@ llvm.func @mul_conf_bf16(%A : vector<32xbf16>,
llvm.return %0 : vector<8xi64>
}

// ----- SET -----
// -- SET --

// CHECK-LABEL: define <16 x i32> @vector_set_128b_into_512b
llvm.func @vector_set_128b_into_512b(%v : vector<4xi32>) -> vector<16xi32> {
Expand All @@ -105,7 +105,7 @@ llvm.func @vector_set_256b_into_512b(%v : vector<8xi32>) -> vector<16xi32> {
llvm.return %1 : vector<16xi32>
}

// ----- SRS -----
// -- SRS --

// CHECK-LABEL: define <16 x i16> @srs_256b_v16_acc32
llvm.func @srs_256b_v16_acc32(%v : vector<8xi64>, %shft : i32, %sign : i32) -> vector<16xi16> {
Expand Down Expand Up @@ -169,7 +169,7 @@ llvm.func @srs_256b_v16_accfloat(%v : vector<8xi64>) -> vector<16xbf16> {
llvm.return %0 : vector<16xbf16>
}

// ----- BROADCAST -----
// -- BROADCAST --

// CHECK-LABEL: define <64 x i8> @vbroadcast8_i512
llvm.func @vbroadcast8_i512(%val : i32) -> vector<64xi8> {
Expand Down Expand Up @@ -211,7 +211,7 @@ llvm.func @vbroadcastfloat_i512(%val : f32) -> vector<16xf32> {
llvm.return %0 : vector<16xf32>
}

// ----- EXT -----
// -- EXT --

// CHECK-LABEL: define <8 x i32> @ext_i256_i512
llvm.func @ext_i256_i512(%v : vector<16xi32>, %idx : i32) -> vector<8xi32> {
Expand Down Expand Up @@ -249,7 +249,7 @@ llvm.func @ext_i128_i512(%v : vector<16xi32>) -> vector<4xi32> {
llvm.return %1 : vector<4xi32>
}

// ----- CONCAT -----
// -- CONCAT --

// CHECK-LABEL: define <16 x i32> @concat_i512_i256
llvm.func @concat_i512_i256(%a : vector<8xi32>, %b : vector<8xi32>) -> vector<16xi32> {
Expand Down Expand Up @@ -280,7 +280,7 @@ llvm.func @concat_i1024_i512(%a : vector<16xi32>, %b : vector<16xi32>) -> vector
llvm.return %0 : vector<32xi32>
}

// ----- SHUFFLE -----
// -- SHUFFLE --

// CHECK-LABEL: define <16 x i32> @shuffle_i512
llvm.func @shuffle_i512(%a : vector<16xi32>, %b : vector<16xi32>, %mode : i32) -> vector<16xi32> {
Expand All @@ -291,7 +291,7 @@ llvm.func @shuffle_i512(%a : vector<16xi32>, %b : vector<16xi32>, %mode : i32) -
llvm.return %0 : vector<16xi32>
}

// ----- UNDEF -----
// -- UNDEF --

// CHECK-LABEL: define <16 x i32> @undef_v16i32
llvm.func @undef_v16i32() -> vector<16xi32> {
Expand All @@ -300,7 +300,7 @@ llvm.func @undef_v16i32() -> vector<16xi32> {
llvm.return %0 : vector<16xi32>
}

// ----- UPD -----
// -- UPD --

// CHECK-LABEL: define <32 x bfloat> @upd_bf512_bf256
llvm.func @upd_bf512_bf256(%a : vector<32xbf16>, %b : vector<16xbf16>, %idx : i32) -> vector<32xbf16> {
Expand All @@ -310,7 +310,7 @@ llvm.func @upd_bf512_bf256(%a : vector<32xbf16>, %b : vector<16xbf16>, %idx : i3
llvm.return %0 : vector<32xbf16>
}

// ----- SHIFT -----
// -- SHIFT --

// CHECK-LABEL: define <16 x i32> @vshift_i512_i512
llvm.func @vshift_i512_i512(%a : vector<16xi32>, %b : vector<16xi32>, %step : i32, %shift : i32) -> vector<16xi32> {
Expand All @@ -330,7 +330,7 @@ llvm.func @vshift_bf512_bf512(%a : vector<32xbf16>, %b : vector<32xbf16>, %step
llvm.return %0 : vector<32xbf16>
}

// ----- EXTRACT ELEMENT -----
// -- EXTRACT ELEMENT --

// CHECK-LABEL: define i32 @vextract_elem8_i512
llvm.func @vextract_elem8_i512(%a : vector<64xi8>, %idx : i32, %sign : i32) -> i32 {
Expand All @@ -356,7 +356,7 @@ llvm.func @vextract_elem32_i512(%a : vector<16xi32>, %idx : i32, %sign : i32) ->
llvm.return %0 : i32
}

// ----- UPS -----
// -- UPS --

// CHECK-LABEL: define <8 x i64> @acc32_v16_i256_ups
llvm.func @acc32_v16_i256_ups(%v : vector<16xi16>, %shift : i32, %sign : i32) -> vector<8xi64> {
Expand Down Expand Up @@ -419,3 +419,17 @@ llvm.func @accfloat_v16_256b_ups(%v : vector<16xbf16>) -> vector<8xi64> {
%0 = "xllvm.intr.aie2.v16bf16.to.v16accfloat"(%v) : (vector<16xbf16>) -> vector<8xi64>
llvm.return %0 : vector<8xi64>
}

// -----

// CHECK-LABEL: <32 x bfloat> @vmax_ltbf16
llvm.func @vmax_ltbf16(%lhs: vector<32xbf16>, %rhs: vector<32xbf16>) -> vector<32xbf16> {
// CHECK: call { <32 x bfloat>, i32 } @llvm.aie2.vmax.ltbf16(
// CHECK-SAME: <32 x bfloat> %{{[0-9]+}}, <32 x bfloat> %{{[0-9]+}})
%0 = "xllvm.intr.aie2.vmax.ltbf16"(%lhs, %rhs) :
(vector<32xbf16>, vector<32xbf16>) -> !llvm.struct<(vector<32xbf16>, i32)>
%1 = llvm.extractvalue %0[0] : !llvm.struct<(vector<32xbf16>, i32)>
llvm.return %1 : vector<32xbf16>
}

// CHECK-LABEL: declare { <32 x bfloat>, i32 } @llvm.aie2.vmax.ltbf16(<32 x bfloat>, <32 x bfloat>)
10 changes: 10 additions & 0 deletions test/dialect/XLLVM/invalid.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// RUN: aie-opt %s -split-input-file -verify-diagnostics

func.func @invalidStructType(%A : vector<32xbf16>, %B : vector<32xbf16>)
-> vector<16xbf16> {
// expected-error @+1 {{'res' is an LLVM struct of {vector of bfloat16 type values of length 32; 32-bit signless integer}}
%rs = "xllvm.intr.aie2.vmax.ltbf16"(%A, %B) :
(vector<32xbf16>, vector<32xbf16>) -> !llvm.struct<(vector<16xbf16>, i32)>
%rv = llvm.extractvalue %rs[0] : !llvm.struct<(vector<16xbf16>, i32)>
return %rv : vector<16xbf16>
}

0 comments on commit 92c6ed9

Please sign in to comment.