[MLIR][OpenMP] Introduce host_eval clause to omp.target
This patch defines a map-like clause named `host_eval`, used to capture host
values for use inside of target regions in restricted cases:
  - As `num_teams` or `thread_limit` of a nested `omp.teams` operation.
  - As `num_threads` of a nested `omp.parallel` operation, or as bounds or steps
    of a nested `omp.loop_nest`, if it is a target SPMD kernel.

This replaces the following `omp.target` arguments: `trip_count`,
`num_threads`, `num_teams_lower`, `num_teams_upper` and `teams_thread_limit`.
skatrak committed Oct 10, 2024
1 parent c3518fb commit 44b6230
Showing 7 changed files with 284 additions and 61 deletions.
56 changes: 55 additions & 1 deletion mlir/docs/Dialects/OpenMPDialect/_index.md
@@ -297,7 +297,8 @@ arguments for the region of that MLIR operation. This enables, for example, the
introduction of private copies of the same underlying variable defined outside
the MLIR operation the clause is attached to. Currently, clauses with this
property can be classified into three main categories:
- Map-like clauses: `map`, `use_device_addr` and `use_device_ptr`.
- Map-like clauses: `host_eval`, `map`, `use_device_addr` and
`use_device_ptr`.
- Reduction-like clauses: `in_reduction`, `reduction` and `task_reduction`.
- Privatization clauses: `private`.

@@ -526,3 +527,56 @@ omp.parallel ... {
omp.terminator
} {omp.composite}
```

## Host-Evaluated Clauses in Target Regions

The `omp.target` operation, used to represent the OpenMP `target` construct, is
an `IsolatedFromAbove` operation, which means no outside MLIR values are allowed
inside of the region defined by it. This is a good match for the semantics of
the construct, since host values used inside of the target region must be
privatized or mapped to be used.

Normally, the evaluation of clauses applied to a given construct must be
completed prior to entering that construct. However, there are clauses for which
the OpenMP specification defines exceptions when nested inside of a target
region. Specifically, the `num_teams` and `thread_limit` clauses of the `teams`
construct must be evaluated on the host if that construct is nested inside of or
combined with a `target` construct.

Additionally, the runtime library targeted by the MLIR to LLVM IR translation of
the OpenMP dialect supports the optimized launch of SPMD kernels (i.e.
`target teams distribute parallel {do,for}` in OpenMP), which requires
specifying the total trip count of the loop in advance. Consequently, it is also
beneficial to evaluate the trip count on the host prior to the kernel launch.
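
As a rough illustration of the value that has to be known before the kernel launch, the sketch below computes the total trip count of a loop nest on the host from its bounds and steps. It is a minimal standalone C++ model, not the code used by the translation: `LoopDim`, `dimTripCount` and `totalTripCount` are hypothetical names, steps are assumed to be positive, and the upper bound is assumed to be inclusive.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical description of one loop dimension; these names are not part of
// MLIR. Assumptions: the upper bound is inclusive and the step is positive.
struct LoopDim {
  int64_t lb;
  int64_t ub;
  int64_t step;
};

// Number of iterations of a single dimension (0 if the range is empty).
inline uint64_t dimTripCount(const LoopDim &d) {
  assert(d.step > 0 && "this sketch only handles positive steps");
  if (d.ub < d.lb)
    return 0;
  return static_cast<uint64_t>((d.ub - d.lb) / d.step) + 1;
}

// Total trip count of a collapsed loop nest: the product over all dimensions.
inline uint64_t totalTripCount(const std::vector<LoopDim> &dims) {
  uint64_t total = 1;
  for (const LoopDim &d : dims)
    total *= dimTripCount(d);
  return total;
}
```

For instance, under these assumptions a nest with dimensions `{0, 9, 1}` and `{1, 10, 3}` has 10 and 4 iterations respectively, so the total passed to the launch would be 40.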

In MLIR, these host-evaluated values would need to be defined outside of the
`omp.target` region and also be attached to the corresponding nested operations,
which is not possible because of the `IsolatedFromAbove` trait. The solution
implemented to address this problem is the `host_eval` argument of the
`omp.target` operation. It works similarly to a `map` clause, but its only
intended use is to forward host-evaluated values to their corresponding
operation inside of the region. Any other use results in a verifier error.

```mlir
// Initialize %0, %1, %2, %3...
omp.target host_eval(%0 -> %nt, %1 -> %lb, %2 -> %ub, %3 -> %step : i32, i32, i32, i32) {
  omp.teams num_teams(to %nt : i32) {
    omp.parallel {
      omp.distribute {
        omp.wsloop {
          omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
            // ...
            omp.yield
          }
          omp.terminator
        } {omp.composite}
        omp.terminator
      } {omp.composite}
      omp.terminator
    } {omp.composite}
    omp.terminator
  }
  omp.terminator
}
```
38 changes: 38 additions & 0 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -419,6 +419,44 @@ class OpenMP_HintClauseSkip<

def OpenMP_HintClause : OpenMP_HintClauseSkip<>;

//===----------------------------------------------------------------------===//
// Not in the spec: Clause-like structure to hold host-evaluated values.
//===----------------------------------------------------------------------===//

class OpenMP_HostEvalClauseSkip<
    bit traits = false, bit arguments = false, bit assemblyFormat = false,
    bit description = false, bit extraClassDeclaration = false
  > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
                    extraClassDeclaration> {
  let traits = [
    BlockArgOpenMPOpInterface
  ];

  let arguments = (ins
    Variadic<AnyType>:$host_eval_vars
  );

  let extraClassDeclaration = [{
    unsigned numHostEvalBlockArgs() {
      return getHostEvalVars().size();
    }
  }];

  let description = [{
    The optional `host_eval_vars` holds values defined outside of the region of
    the `IsolatedFromAbove` operation for which a corresponding entry block
    argument is defined. The only legal uses for these captured values are the
    following:
      - `num_teams` or `thread_limit` clause of an immediately nested
        `omp.teams` operation.
      - If the operation is the top-level `omp.target` of a target SPMD kernel:
        - `num_threads` clause of the nested `omp.parallel` operation.
        - Bounds and steps of the nested `omp.loop_nest` operation.
  }];
}

def OpenMP_HostEvalClause : OpenMP_HostEvalClauseSkip<>;

//===----------------------------------------------------------------------===//
// V5.2: [3.4] `if` clause
//===----------------------------------------------------------------------===//
31 changes: 10 additions & 21 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1100,20 +1100,16 @@ def TargetUpdateOp: OpenMP_Op<"target_update", traits = [
// 2.14.5 target construct
//===----------------------------------------------------------------------===//

// TODO: Remove num_threads, teams_thread_limit and trip_count and implement the
// passthrough approach described here:
// https://discourse.llvm.org/t/rfc-openmp-dialect-representation-of-num-teams-thread-limit-and-target-spmd/81106.
def TargetOp : OpenMP_Op<"target", traits = [
AttrSizedOperandSegments, BlockArgOpenMPOpInterface, IsolatedFromAbove,
OutlineableOpenMPOpInterface
], clauses = [
// TODO: Complete clause list (defaultmap, uses_allocators).
OpenMP_AllocateClause, OpenMP_DependClause, OpenMP_DeviceClause,
OpenMP_HasDeviceAddrClause, OpenMP_IfClause, OpenMP_InReductionClause,
OpenMP_IsDevicePtrClause, OpenMP_MapClauseSkip<assemblyFormat = true>,
OpenMP_NowaitClause, OpenMP_NumTeamsClauseSkip<description = true>,
OpenMP_NumThreadsClauseSkip<description = true>, OpenMP_PrivateClause,
OpenMP_ThreadLimitClause
OpenMP_HasDeviceAddrClause, OpenMP_HostEvalClause, OpenMP_IfClause,
OpenMP_InReductionClause, OpenMP_IsDevicePtrClause,
OpenMP_MapClauseSkip<assemblyFormat = true>, OpenMP_NowaitClause,
OpenMP_PrivateClause, OpenMP_ThreadLimitClause
], singleRegion = true> {
let summary = "target construct";
let description = [{
@@ -1140,10 +1136,6 @@ def TargetOp : OpenMP_Op<"target", traits = [
an `omp.parallel`.
}] # clausesDescription;

let arguments = !con(clausesArgs,
(ins Optional<AnyInteger>:$trip_count,
Optional<AnyInteger>:$teams_thread_limit));

let builders = [
OpBuilder<(ins CArg<"const TargetOperands &">:$clauses)>
];
@@ -1168,15 +1160,12 @@ def TargetOp : OpenMP_Op<"target", traits = [
bool isTargetSPMDLoop();
}] # clausesExtraClassDeclaration;

let assemblyFormat = clausesReqAssemblyFormat #
" oilist(" # clausesOptAssemblyFormat # [{
| `trip_count` `(` $trip_count `:` type($trip_count) `)`
| `teams_thread_limit` `(` $teams_thread_limit `:` type($teams_thread_limit) `)`
}] # ")" # [{
custom<InReductionMapPrivateRegion>(
$region, $in_reduction_vars, type($in_reduction_vars),
$in_reduction_byref, $in_reduction_syms, $map_vars, type($map_vars),
$private_vars, type($private_vars), $private_syms) attr-dict
let assemblyFormat = clausesAssemblyFormat # [{
custom<HostEvalInReductionMapPrivateRegion>(
$region, $host_eval_vars, type($host_eval_vars), $in_reduction_vars,
type($in_reduction_vars), $in_reduction_byref, $in_reduction_syms,
$map_vars, type($map_vars), $private_vars, type($private_vars),
$private_syms) attr-dict
}];

let hasVerifier = 1;
27 changes: 22 additions & 5 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
@@ -25,6 +25,10 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {

let methods = [
// Default-implemented methods to be overridden by the corresponding clauses.
InterfaceMethod<"Get number of block arguments defined by `host_eval`.",
"unsigned", "numHostEvalBlockArgs", (ins), [{}], [{
return 0;
}]>,
InterfaceMethod<"Get number of block arguments defined by `in_reduction`.",
"unsigned", "numInReductionBlockArgs", (ins), [{}], [{
return 0;
@@ -55,9 +59,14 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
}]>,

// Unified access methods for clause-associated entry block arguments.
InterfaceMethod<"Get start index of block arguments defined by `host_eval`.",
"unsigned", "getHostEvalBlockArgsStart", (ins), [{
return 0;
}]>,
InterfaceMethod<"Get start index of block arguments defined by `in_reduction`.",
"unsigned", "getInReductionBlockArgsStart", (ins), [{
return 0;
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
return iface.getHostEvalBlockArgsStart() + $_op.numHostEvalBlockArgs();
}]>,
InterfaceMethod<"Get start index of block arguments defined by `map`.",
"unsigned", "getMapBlockArgsStart", (ins), [{
@@ -91,6 +100,13 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
return iface.getUseDeviceAddrBlockArgsStart() + $_op.numUseDeviceAddrBlockArgs();
}]>,

InterfaceMethod<"Get block arguments defined by `host_eval`.",
"::llvm::MutableArrayRef<::mlir::BlockArgument>",
"getHostEvalBlockArgs", (ins), [{
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
return $_op->getRegion(0).getArguments().slice(
iface.getHostEvalBlockArgsStart(), $_op.numHostEvalBlockArgs());
}]>,
InterfaceMethod<"Get block arguments defined by `in_reduction`.",
"::llvm::MutableArrayRef<::mlir::BlockArgument>",
"getInReductionBlockArgs", (ins), [{
@@ -147,10 +163,11 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {

let verify = [{
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>($_op);
unsigned expectedArgs = iface.numInReductionBlockArgs() +
iface.numMapBlockArgs() + iface.numPrivateBlockArgs() +
iface.numReductionBlockArgs() + iface.numTaskReductionBlockArgs() +
iface.numUseDeviceAddrBlockArgs() + iface.numUseDevicePtrBlockArgs();
unsigned expectedArgs = iface.numHostEvalBlockArgs() +
iface.numInReductionBlockArgs() + iface.numMapBlockArgs() +
iface.numPrivateBlockArgs() + iface.numReductionBlockArgs() +
iface.numTaskReductionBlockArgs() + iface.numUseDeviceAddrBlockArgs() +
iface.numUseDevicePtrBlockArgs();
if ($_op->getRegion(0).getNumArguments() < expectedArgs)
return $_op->emitOpError() << "expected at least " << expectedArgs
<< " entry block argument(s)";
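
The start-index accessors and the verifier in this interface follow one layout rule: each clause's entry block arguments begin where the previous clause's arguments end, and the region needs at least as many arguments as all clauses define together. The sketch below is a minimal standalone C++ model of that rule, not the generated interface code. `Clause`, `ClauseCounts`, `blockArgsStart` and `hasEnoughEntryBlockArgs` are hypothetical names, and the clause order is an assumption taken from the order in which the verifier sums the per-clause counts.

```cpp
#include <array>
#include <cstddef>

// Hypothetical standalone model; none of these names exist in MLIR.
// Assumption: clauses are laid out in the same order in which the interface
// verifier sums their block argument counts, with `host_eval` first.
enum Clause : std::size_t {
  HostEval,
  InReduction,
  Map,
  Private,
  Reduction,
  TaskReduction,
  UseDeviceAddr,
  UseDevicePtr,
  NumClauses
};

using ClauseCounts = std::array<unsigned, NumClauses>;

// Start index of a clause's block arguments: the sum of the counts of all
// clauses that precede it, mirroring the chained get*BlockArgsStart() methods.
inline unsigned blockArgsStart(const ClauseCounts &counts, Clause clause) {
  unsigned start = 0;
  for (std::size_t i = 0; i < clause; ++i)
    start += counts[i];
  return start;
}

// Mirrors the verifier check: the entry block must have at least as many
// arguments as all clauses define together.
inline bool hasEnoughEntryBlockArgs(const ClauseCounts &counts,
                                    unsigned numRegionArgs) {
  return numRegionArgs >= blockArgsStart(counts, NumClauses);
}
```

Because `host_eval` comes first, adding it shifts every other clause's arguments by `numHostEvalBlockArgs()`, which is why `getInReductionBlockArgsStart` no longer returns a constant 0.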
82 changes: 50 additions & 32 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -498,6 +498,7 @@ struct ReductionParseArgs {
: vars(vars), types(types), byref(byref), syms(syms) {}
};
struct AllRegionParseArgs {
std::optional<MapParseArgs> hostEvalArgs;
std::optional<ReductionParseArgs> inReductionArgs;
std::optional<MapParseArgs> mapArgs;
std::optional<PrivateParseArgs> privateArgs;
@@ -624,6 +625,11 @@ static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
AllRegionParseArgs args) {
llvm::SmallVector<OpAsmParser::Argument> entryBlockArgs;

if (failed(parseBlockArgClause(parser, entryBlockArgs, "host_eval",
args.hostEvalArgs)))
return parser.emitError(parser.getCurrentLocation())
<< "invalid `host_eval` format";

if (failed(parseBlockArgClause(parser, entryBlockArgs, "in_reduction",
args.inReductionArgs)))
return parser.emitError(parser.getCurrentLocation())
@@ -662,8 +668,10 @@ static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
return parser.parseRegion(region, entryBlockArgs);
}

static ParseResult parseInReductionMapPrivateRegion(
static ParseResult parseHostEvalInReductionMapPrivateRegion(
OpAsmParser &parser, Region &region,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &hostEvalVars,
SmallVectorImpl<Type> &hostEvalTypes,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars,
SmallVectorImpl<Type> &inReductionTypes,
DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
@@ -672,6 +680,7 @@ static ParseResult parseInReductionMapPrivateRegion(
llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
AllRegionParseArgs args;
args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
inReductionByref, inReductionSyms);
args.mapArgs.emplace(mapVars, mapTypes);
@@ -785,6 +794,7 @@ struct ReductionPrintArgs {
: vars(vars), types(types), byref(byref), syms(syms) {}
};
struct AllRegionPrintArgs {
std::optional<MapPrintArgs> hostEvalArgs;
std::optional<ReductionPrintArgs> inReductionArgs;
std::optional<MapPrintArgs> mapArgs;
std::optional<PrivatePrintArgs> privateArgs;
@@ -863,6 +873,8 @@ static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op);
MLIRContext *ctx = op->getContext();

printBlockArgClause(p, ctx, "host_eval", iface.getHostEvalBlockArgs(),
args.hostEvalArgs);
printBlockArgClause(p, ctx, "in_reduction", iface.getInReductionBlockArgs(),
args.inReductionArgs);
printBlockArgClause(p, ctx, "map_entries", iface.getMapBlockArgs(),
@@ -883,12 +895,14 @@ static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
p.printRegion(region, /*printEntryBlockArgs=*/false);
}

static void printInReductionMapPrivateRegion(
OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
static void printHostEvalInReductionMapPrivateRegion(
OpAsmPrinter &p, Operation *op, Region &region, ValueRange hostEvalVars,
TypeRange hostEvalTypes, ValueRange inReductionVars,
TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes,
ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms) {
AllRegionPrintArgs args;
args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
inReductionByref, inReductionSyms);
args.mapArgs.emplace(mapVars, mapTypes);
@@ -966,6 +980,7 @@ static void printUseDeviceAddrUseDevicePtrRegion(OpAsmPrinter &p, Operation *op,
args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
printBlockArgRegion(p, op, region, args);
}

/// Verifies Reduction Clause
static LogicalResult
verifyReductionVarList(Operation *op, std::optional<ArrayAttr> reductionSyms,
@@ -1651,14 +1666,12 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
// inReductionByref, inReductionSyms.
TargetOp::build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
clauses.device, clauses.hasDeviceAddrVars, clauses.ifExpr,
clauses.device, clauses.hasDeviceAddrVars,
clauses.hostEvalVars, clauses.ifExpr,
/*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
/*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars,
clauses.mapVars, clauses.nowait, /*num_teams_lower=*/nullptr,
/*num_teams_upper=*/nullptr, /*num_threads_var=*/nullptr,
clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
clauses.threadLimit,
/*trip_count=*/nullptr, /*teams_thread_limit=*/nullptr);
clauses.mapVars, clauses.nowait, clauses.privateVars,
makeArrayAttr(ctx, clauses.privateSyms), clauses.threadLimit);
}

/// Only allow OpenMP terminators and non-OpenMP ops that have known memory
@@ -1707,18 +1720,31 @@ LogicalResult TargetOp::verify() {
if (std::distance(teamsOps.begin(), teamsOps.end()) > 1)
return emitError("target containing multiple teams constructs");

if (!isTargetSPMDLoop() && getTripCount())
return emitError("trip_count set on non-SPMD target region");

if (teamsOps.empty()) {
if (getNumTeamsLower() || getNumTeamsUpper() || getTeamsThreadLimit())
return emitError(
"num_teams and teams_thread_limit arguments only allowed if there is "
"an omp.teams child operation");
} else {
if (failed(verifyNumTeamsClause(*this, getNumTeamsLower(),
getNumTeamsUpper())))
return failure();
  // Check that host_eval values are only used in legal ways.
  bool isTargetSPMD = isTargetSPMDLoop();
  for (Value hostEvalArg :
       cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
    for (Operation *user : hostEvalArg.getUsers()) {
      if (auto teamsOp = dyn_cast<TeamsOp>(user)) {
        if (llvm::is_contained({teamsOp.getNumTeamsLower(),
                                teamsOp.getNumTeamsUpper(),
                                teamsOp.getThreadLimit()},
                               hostEvalArg))
          continue;
      } else if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
        if (isTargetSPMD && hostEvalArg == parallelOp.getNumThreads())
          continue;
      } else if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
        if (isTargetSPMD &&
            (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
             llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
             llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg))) {
          continue;
        }
      }
      return emitOpError() << "host_eval argument illegal use in '"
                           << user->getName() << "' operation";
    }
  }
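
The rule enforced by this loop can be summarized as a small decision table. The sketch below is a simplified standalone C++ model of it, not the verifier itself: `UserOp`, `OperandRole` and `isLegalHostEvalUse` are hypothetical names, and each use is reduced to a single (user operation, operand role) pair, whereas the real check only tests whether the value appears among the listed operands of the user.

```cpp
// Hypothetical standalone model of the host_eval legality rule; these types
// are not MLIR classes and only approximate the verifier's behavior.
enum class UserOp { Teams, Parallel, LoopNest, Other };

enum class OperandRole {
  NumTeamsLower,
  NumTeamsUpper,
  ThreadLimit,
  NumThreads,
  LowerBound,
  UpperBound,
  Step,
  Other
};

// Returns true if a host_eval block argument may be used by `user` in the
// given operand role. `isTargetSPMD` models the result of isTargetSPMDLoop().
inline bool isLegalHostEvalUse(UserOp user, OperandRole role,
                               bool isTargetSPMD) {
  switch (user) {
  case UserOp::Teams:
    // Always legal as num_teams bounds or thread_limit of omp.teams.
    return role == OperandRole::NumTeamsLower ||
           role == OperandRole::NumTeamsUpper ||
           role == OperandRole::ThreadLimit;
  case UserOp::Parallel:
    // Only legal as num_threads, and only in a target SPMD kernel.
    return isTargetSPMD && role == OperandRole::NumThreads;
  case UserOp::LoopNest:
    // Only legal as loop bounds or steps, and only in a target SPMD kernel.
    return isTargetSPMD && (role == OperandRole::LowerBound ||
                            role == OperandRole::UpperBound ||
                            role == OperandRole::Step);
  case UserOp::Other:
    return false;
  }
  return false;
}
```

In this model, a `num_threads` use on `omp.parallel` is rejected unless the target region is an SPMD loop, while `num_teams` and `thread_limit` uses on `omp.teams` are accepted either way.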

LogicalResult verifyDependVars =
@@ -1948,17 +1974,9 @@ LogicalResult TeamsOp::verify() {
return emitError("expected to be nested inside of omp.target or not nested "
"in any OpenMP dialect operations");

auto offloadModOp =
llvm::cast<OffloadModuleInterface>(*(*this)->getParentOfType<ModuleOp>());
if (targetOp && !offloadModOp.getIsTargetDevice()) {
if (getNumTeamsLower() || getNumTeamsUpper() || getThreadLimit())
return emitError("num_teams and thread_limit arguments expected to be "
"attached to parent omp.target operation");
} else {
if (failed(verifyNumTeamsClause(*this, getNumTeamsLower(),
getNumTeamsUpper())))
return failure();
}
if (failed(
verifyNumTeamsClause(*this, getNumTeamsLower(), getNumTeamsUpper())))
return failure();

// Check for allocate clause restrictions
if (getAllocateVars().size() != getAllocatorVars().size())