Skip to content

Commit

Permalink
[MLIR] EmitC: Add subscript operator
Browse files Browse the repository at this point in the history
Reviewers: TinaAMD

Reviewed By: TinaAMD

Pull Request: #118
  • Loading branch information
mgehre-amd authored Feb 28, 2024
1 parent d3c90a4 commit c938fb8
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 8 deletions.
32 changes: 32 additions & 0 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
Original file line number Diff line number Diff line change
Expand Up @@ -565,4 +565,36 @@ def EmitC_IfOp : EmitC_Op<"if",
let hasCustomAssemblyFormat = 1;
}

def EmitC_SubscriptOp : EmitC_Op<"subscript",
[TypesMatchWith<"result type matches element type of 'array'",
"array", "result",
"::llvm::cast<ArrayType>($_self).getElementType()">]> {
let summary = "Array subscript operation";
let description = [{
With the `subscript` operation the subscript operator `[]` can be applied
to variables or arguments of array type.

Example:

```mlir
%i = index.constant 1
%j = index.constant 7
%0 = emitc.subscript %arg0[%i][%j] : (!emitc.array<4x8xf32>) -> f32
```
}];
let arguments = (ins Arg<EmitC_ArrayType, "the reference to load from">:$array,
Variadic<Index>:$indices);
let results = (outs AnyType:$result);

let builders = [
OpBuilder<(ins "Value":$array, "ValueRange":$indices), [{
build($_builder, $_state, cast<ArrayType>(array.getType()).getElementType(), array, indices);
}]>
];

let hasVerifier = 1;
let assemblyFormat = "$array `[` $indices `]` attr-dict `:` type($array)";
}


#endif // MLIR_DIALECT_EMITC_IR_EMITC
19 changes: 17 additions & 2 deletions mlir/lib/Dialect/EmitC/IR/EmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,10 @@ LogicalResult ApplyOp::verify() {
LogicalResult emitc::AssignOp::verify() {
Value variable = getVar();
Operation *variableDef = variable.getDefiningOp();
if (!variableDef || !llvm::isa<emitc::VariableOp>(variableDef))
if (!variableDef ||
!llvm::isa<emitc::VariableOp, emitc::SubscriptOp>(variableDef))
return emitOpError() << "requires first operand (" << variable
<< ") to be a Variable";
<< ") to be a Variable or subscript";

Value value = getValue();
if (variable.getType() != value.getType())
Expand Down Expand Up @@ -530,6 +531,20 @@ LogicalResult emitc::VariableOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// SubscriptOp
//===----------------------------------------------------------------------===//

LogicalResult emitc::SubscriptOp::verify() {
if (getIndices().size() != (size_t)getArray().getType().getRank()) {
return emitOpError() << "requires number of indices ("
<< getIndices().size()
<< ") to match the rank of the array type ("
<< getArray().getType().getRank() << ")";
}
return success();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
Expand Down
39 changes: 33 additions & 6 deletions mlir/lib/Target/Cpp/TranslateToCpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ struct CppEmitter {
/// Return the existing or a new name for a Value.
StringRef getOrCreateName(Value val);

// Returns the textual representation of a subscript operation.
std::string getSubscriptName(emitc::SubscriptOp op);

/// Return the existing or a new label of a Block.
StringRef getOrCreateName(Block &block);

Expand Down Expand Up @@ -251,8 +254,7 @@ static LogicalResult printOperation(CppEmitter &emitter,

static LogicalResult printOperation(CppEmitter &emitter,
emitc::AssignOp assignOp) {
auto variableOp = cast<emitc::VariableOp>(assignOp.getVar().getDefiningOp());
OpResult result = variableOp->getResult(0);
OpResult result = assignOp.getVar().getDefiningOp()->getResult(0);

if (failed(emitter.emitVariableAssignment(result)))
return failure();
Expand All @@ -262,6 +264,13 @@ static LogicalResult printOperation(CppEmitter &emitter,
return success();
}

static LogicalResult printOperation(CppEmitter &emitter,
emitc::SubscriptOp subscriptOp) {
// Add name to cache so that `hasValueInScope` works.
emitter.getOrCreateName(subscriptOp.getResult());
return success();
}

static LogicalResult printBinaryOperation(CppEmitter &emitter,
Operation *operation,
StringRef binaryOperator) {
Expand Down Expand Up @@ -706,12 +715,28 @@ CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop)
labelInScopeCount.push(0);
}

std::string CppEmitter::getSubscriptName(emitc::SubscriptOp op) {
std::string out;
llvm::raw_string_ostream ss(out);
ss << getOrCreateName(op.getArray());
for (auto index : op.getIndices()) {
ss << "[" << getOrCreateName(index) << "]";
}
return out;
}

/// Return the existing or a new name for a Value.
StringRef CppEmitter::getOrCreateName(Value val) {
if (auto literal = dyn_cast_if_present<emitc::LiteralOp>(val.getDefiningOp()))
return literal.getValue();
if (!valueMapper.count(val))
valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
if (!valueMapper.count(val)) {
if (auto subscript =
dyn_cast_if_present<emitc::SubscriptOp>(val.getDefiningOp())) {
valueMapper.insert(val, getSubscriptName(subscript));
} else {
valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
}
}
return *valueMapper.begin(val);
}

Expand Down Expand Up @@ -891,6 +916,8 @@ LogicalResult CppEmitter::emitVariableAssignment(OpResult result) {

LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
bool trailingSemicolon) {
if (isa<emitc::SubscriptOp>(result.getDefiningOp()))
return success();
if (hasValueInScope(result)) {
return result.getDefiningOp()->emitError(
"result variable for the operation already declared");
Expand Down Expand Up @@ -957,7 +984,7 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
emitc::CallOpaqueOp, emitc::CastOp, emitc::CmpOp,
emitc::ConstantOp, emitc::DivOp, emitc::ForOp, emitc::IfOp,
emitc::IncludeOp, emitc::MulOp, emitc::RemOp, emitc::SubOp,
emitc::VariableOp>(
emitc::SubscriptOp, emitc::VariableOp>(
[&](auto op) { return printOperation(*this, op); })
// Func ops.
.Case<func::CallOp, func::ConstantOp, func::FuncOp, func::ReturnOp>(
Expand All @@ -973,7 +1000,7 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
if (failed(status))
return failure();

if (isa<emitc::LiteralOp>(op))
if (isa<emitc::LiteralOp, emitc::SubscriptOp>(op))
return success();

os << (trailingSemicolon ? ";\n" : "\n");
Expand Down
8 changes: 8 additions & 0 deletions mlir/test/Dialect/EmitC/invalid_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -224,3 +224,11 @@ func.func @test_assign_type_mismatch(%arg1: f32) {
emitc.assign %arg1 : f32 to %v : i32
return
}

// -----

func.func @test_subscript_indices_mismatch(%arg0: !emitc.array<4x8xf32>, %arg2: index) {
// expected-error @+1 {{'emitc.subscript' op requires number of indices (1) to match the rank of the array type (2)}}
%0 = emitc.subscript %arg0[%arg2] : <4x8xf32>
return
}
8 changes: 8 additions & 0 deletions mlir/test/Dialect/EmitC/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -149,3 +149,11 @@ func.func @test_for_not_index_induction(%arg0 : i16, %arg1 : i16, %arg2 : i16) {
}
return
}

func.func @test_subscript(%arg0: !emitc.array<4x8xf32>, %arg1: !emitc.array<3x5xf32>,
%arg2: index, %arg3: index) {
%0 = emitc.subscript %arg0[%arg2, %arg3] : <4x8xf32>
%1 = emitc.subscript %arg1[%arg2, %arg3] : <3x5xf32>
emitc.assign %0 : f32 to %1 : f32
return
}
12 changes: 12 additions & 0 deletions mlir/test/Target/Cpp/subscript.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s
// RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s

func.func @load_store(%arg0: !emitc.array<4x8xf32>, %arg1: !emitc.array<3x5xf32>, %arg2: index, %arg3: index) {
%0 = emitc.subscript %arg0[%arg2, %arg3] : <4x8xf32>
%1 = emitc.subscript %arg1[%arg2, %arg3] : <3x5xf32>
emitc.assign %0 : f32 to %1 : f32
return
}
// CHECK: void load_store(float [[ARR1:[^ ]*]][4][8], float [[ARR2:[^ ]*]][3][5],
// CHECK-SAME: size_t [[I:[^ ]*]], size_t [[J:[^ ]*]])
// CHECK-NEXT: [[ARR2]][[[I]]][[[J]]] = [[ARR1]][[[I]]][[[J]]];

0 comments on commit c938fb8

Please sign in to comment.