Skip to content

Commit

Permalink
[XLA:GPU] Fuse loops with the same map that do not affect each other.…
Browse files Browse the repository at this point in the history
… This allows for more cse & inlining.

PiperOrigin-RevId: 684341569
  • Loading branch information
vwbaker authored and Google-ML-Automation committed Oct 10, 2024
1 parent 3064645 commit 80784a0
Show file tree
Hide file tree
Showing 2 changed files with 177 additions and 0 deletions.
84 changes: 84 additions & 0 deletions xla/service/gpu/fusions/transforms/fuse_loops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/MLIRContext.h"
Expand Down Expand Up @@ -109,6 +110,16 @@ bool IndicesAreEqualAndInjective(int64_t iv_count, mv::InsertOp insert,
return llvm::all_of(matched_indices, [](bool matched) { return matched; });
}

// Returns true if `dominator` (the last loop of the candidate set) properly
// dominates every user of every result produced by `dominatee`. When this
// holds, `dominatee`'s body can be relocated to `dominator`'s position
// without any use of its results appearing before its definition.
bool LoopDominatesLoop(LoopOp dominator /*lastloop*/, LoopOp dominatee) {
  mlir::DominanceInfo dominance;
  for (Value result : dominatee.getResults()) {
    for (Operation* user : result.getUsers()) {
      // enclosingOpOk=false: a user nested inside `dominator` itself does not
      // count as being dominated by it.
      if (!dominance.properlyDominates(dominator, user,
                                       /*enclosingOpOk=*/false)) {
        return false;
      }
    }
  }
  return true;
}

// Fuse insert_loop and extract_loop into a single loop, and remove the
// vector.insert and vector.extract ops.
void FuseExtractInsertLoopPair(MLIRContext* mlir_context, LoopOp insert_loop,
Expand Down Expand Up @@ -176,6 +187,68 @@ void FuseExtractInsertLoopPair(MLIRContext* mlir_context, LoopOp insert_loop,
rewriter.eraseOp(extract_loop);
}

// Fuse loops that have the same map, same dim variables, & can be rewritten as
// a single loop, each stacked on top of the next.
void FuseIndependentLoops(MLIRContext* mlir_context,
SmallVector<LoopOp>& loops) {
auto last_loop = loops.back();
auto map = last_loop.getIndexingMap();
mlir::IRRewriter rewriter(mlir_context);
rewriter.setInsertionPointAfter(last_loop);

SmallVector<Value> inits;
SmallVector<Value> results;
for (auto loop : loops) {
inits.append(loop.getInits().begin(), loop.getInits().end());
auto yield_op = loop.getBody()->getTerminator();
auto yields = yield_op->getOperands();
results.append(yields.begin(), yields.end());
yield_op->erase();
}
auto new_loop = rewriter.create<LoopOp>(last_loop.getLoc(), map,
last_loop.getDims(), inits);

auto new_args = new_loop.getRegion().front().getArguments();
int common_args_count = map.GetRangeVarsCount() + map.GetNumResults();
auto common_args = new_args.take_front(common_args_count);
auto init_args = new_args.drop_front(common_args_count);
auto new_results = new_loop.getResults();

for (auto loop : loops) {
int num_results = loop.getNumResults();
loop->replaceAllUsesWith(new_results.take_front(num_results));
new_results = new_results.drop_front(num_results);
SmallVector<Value> old_args(common_args);
auto old_inits = init_args.take_front(num_results);
old_args.append(old_inits.begin(), old_inits.end());
init_args = init_args.drop_front(num_results);

rewriter.mergeBlocks(&loop.getRegion().front(),
&new_loop.getRegion().front(), old_args);
rewriter.eraseOp(loop);
}
rewriter.setInsertionPointToEnd(new_loop.getBody());
rewriter.create<YieldOp>(new_loop.getLoc(), results);
}

// Given loops that all share one indexing map, fuses into the last loop every
// earlier loop that can legally be sunk to the last loop's position. Mutates
// `loops` by dropping its last element.
void FuseSameMapLoopsIfPossible(MLIRContext* mlir_context,
                                SmallVector<LoopOp>& loops) {
  // Fusion needs at least two candidates.
  if (loops.size() < 2) return;

  LoopOp last_loop = loops.pop_back_val();

  // A loop is fusible iff every user of its results comes after `last_loop`
  // (so its body can be moved there) and it uses the same dim operands.
  SmallVector<LoopOp> fusible_loops;
  for (LoopOp candidate : loops) {
    if (!LoopDominatesLoop(/*dominator=*/last_loop, /*dominatee=*/candidate)) {
      continue;
    }
    if (!LoopsUseSameDimOps(last_loop, candidate)) {
      continue;
    }
    fusible_loops.push_back(candidate);
  }
  fusible_loops.push_back(last_loop);

  if (fusible_loops.size() < 2) return;
  FuseIndependentLoops(mlir_context, fusible_loops);
}

void FuseExtractIfPossible(MLIRContext* mlir_context, mv::ExtractOp extract) {
// Check that it has the following pattern:
// %insert_loop = { %insert = vector.insert ... }
Expand Down Expand Up @@ -228,6 +301,17 @@ struct FuseLoopsPass : public impl::FuseLoopsPassBase<FuseLoopsPass> {
for (auto extract : extracts) {
FuseExtractIfPossible(mlir_context, extract);
}

// Fuse loops with the same map & that do not affect each other.
mlir::DenseMap<mlir::Attribute, SmallVector<LoopOp>> loops_by_map;
getOperation()->walk([&](Operation* op) -> void {
if (auto loop = mlir::dyn_cast<LoopOp>(op)) {
loops_by_map[loop.getIndexingMapAttr()].push_back(loop);
}
});
for (auto [_, loops] : loops_by_map) {
FuseSameMapLoopsIfPossible(mlir_context, loops);
}
}
};

Expand Down
93 changes: 93 additions & 0 deletions xla/service/gpu/fusions/transforms/tests/fuse_loops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -290,3 +290,96 @@ func.func @do_not_fuse_unused_loop_iv(%arg0: tensor<20x160x170xf32>) -> tensor<1
// CHECK: vector.insert
// CHECK: xla_gpu.loop
// CHECK: vector.extract

// -----

// Indexing map shared by both loops below: (%tid, %bid) plus symbols
// (s0, s1) -> a (row, col) position in the 98304x4 output tensors.
#indexing_map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] ->"
  " ((d0 floordiv 32) * 8192 + d1 * 8 + s0 * 32768 + (d0 floordiv 4) mod 8,"
  " d0 mod 4),"
  " domain:"
  " d0 in [0, 127], d1 in [0, 1023],"
  " s0 in [0, 2], s1 in [0, 0]">
// Flattened row index -> (row, col) position used to read the 8192x4 input.
#indexing_map1 = #xla_gpu.indexing_map<"(d0) -> "
  " ((d0 floordiv 4) mod 8192,"
  " d0 mod 4),"
  " domain:"
  " d0 in [0, 98303]">
// The two loops share #indexing_map and the same dim operands (%tid, %bid),
// and neither loop's result is used by the other, so the pass should fuse
// them into a single loop that carries both iter_args and yields both tensors.
func.func @fuse_identical_independent_loops(%arg0: tensor<8192x4xf64>,
    %arg1: tensor<98304x4xf64>, %arg2: tensor<98304x4xf64>) ->
    tensor<98304x4xf64> {
  %tid = gpu.thread_id x {xla.range = [0 : index, 127 : index]}
  %bid = gpu.block_id x {xla.range = [0 : index, 1023 : index]}
  %cst_2 = arith.constant 0.50000000000000089 : f64
  %cst = arith.constant 0 : index
  // First loop: inserts the extracted input element scaled by %cst_2.
  %xla_loop = xla_gpu.loop (%tid, %bid)[%i, %j] -> (%ra, %rb) in #indexing_map
      iter_args(%iter = %arg1) -> (tensor<98304x4xf64>) {
    %0:2 = xla_gpu.apply_indexing #indexing_map1(%ra)
    %extracted = tensor.extract %arg0[%0#0, %0#1] : tensor<8192x4xf64>
    %3 = arith.mulf %extracted, %cst_2 : f64
    %inserted = tensor.insert %3 into %iter[%ra, %rb] : tensor<98304x4xf64>
    xla_gpu.yield %inserted : tensor<98304x4xf64>
  }
  // Second loop: inserts the extracted input element unmodified.
  %xla_loop_1 = xla_gpu.loop (%tid, %bid)[%i, %j] -> (%ra, %rb) in #indexing_map
      iter_args(%iter = %arg2) -> (tensor<98304x4xf64>) {
    %0:2 = xla_gpu.apply_indexing #indexing_map1(%ra)
    %extracted = tensor.extract %arg0[%0#0, %0#1] : tensor<8192x4xf64>
    %inserted = tensor.insert %extracted into %iter[%ra, %rb] :
        tensor<98304x4xf64>
    xla_gpu.yield %inserted : tensor<98304x4xf64>
  }
  return %xla_loop_1 : tensor<98304x4xf64>
}

// CHECK-LABEL: @fuse_identical_independent_loops
// CHECK-SAME: (%[[ARG0:.*]]: tensor<8192x4xf64>,
// CHECK-SAME: %[[ARG1:.*]]: tensor<98304x4xf64>,
// CHECK-SAME: %[[ARG2:.*]]: tensor<98304x4xf64>)
// CHECK: %[[LOOP0:.*]], %[[LOOP1:.*]] = xla_gpu.loop
// CHECK-SAME: -> (%[[RA:.*]], %[[RB:.*]]) in
// CHECK-SAME: iter_args(%[[ITER0:.*]] = %[[ARG1]], %[[ITER1:.*]] = %[[ARG2]])
// CHECK: tensor.insert {{.*}} into %[[ITER0]][%[[RA]], %[[RB]]]
// CHECK: tensor.insert {{.*}} into %[[ITER1]][%[[RA]], %[[RB]]]
// CHECK: xla_gpu.yield {{.*}} : tensor<98304x4xf64>, tensor<98304x4xf64>

// -----

// Same pair of maps as the positive test case above (output map over
// (%tid, %bid) with symbols, and a flattened-index read map for the input).
#indexing_map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] ->"
  " ((d0 floordiv 32) * 8192 + d1 * 8 + s0 * 32768 + (d0 floordiv 4) mod 8,"
  " d0 mod 4),"
  " domain:"
  " d0 in [0, 127], d1 in [0, 1023],"
  " s0 in [0, 2], s1 in [0, 0]">
#indexing_map1 = #xla_gpu.indexing_map<"(d0) -> "
  " ((d0 floordiv 4) mod 8192,"
  " d0 mod 4),"
  " domain:"
  " d0 in [0, 98303]">
// Negative test: %dependency consumes %xla_loop's result *before* the second
// loop, and the second loop's iter_args init reads %dependency. The last loop
// therefore does not dominate all users of the first loop's result, so the
// two loops must stay separate.
func.func @do_not_fuse_dependent_loops(%arg0: tensor<8192x4xf64>,
    %arg1: tensor<98304x4xf64>) -> tensor<98304x4xf64> {
  %tid = gpu.thread_id x {xla.range = [0 : index, 127 : index]}
  %bid = gpu.block_id x {xla.range = [0 : index, 1023 : index]}
  %cst_2 = arith.constant 0.50000000000000089 : f64
  %cst = arith.constant 0 : index
  %xla_loop = xla_gpu.loop (%tid, %bid)[%i, %j] -> (%ra, %rb) in #indexing_map
      iter_args(%iter = %arg1) -> (tensor<98304x4xf64>) {
    %0:2 = xla_gpu.apply_indexing #indexing_map1(%ra)
    %extracted = tensor.extract %arg0[%0#0, %0#1] : tensor<8192x4xf64>
    %3 = arith.mulf %extracted, %cst_2 : f64
    %inserted = tensor.insert %3 into %iter[%ra, %rb] : tensor<98304x4xf64>
    xla_gpu.yield %inserted : tensor<98304x4xf64>
  }
  // Intervening user of the first loop's result — creates the dependency.
  %dependency = tensor.insert %cst_2 into %xla_loop[%cst, %cst] :
      tensor<98304x4xf64>
  %xla_loop_1 = xla_gpu.loop (%tid, %bid)[%i, %j] -> (%ra, %rb) in #indexing_map
      iter_args(%iter = %dependency) -> (tensor<98304x4xf64>) {
    %0:2 = xla_gpu.apply_indexing #indexing_map1(%ra)
    %extracted = tensor.extract %arg0[%0#0, %0#1] : tensor<8192x4xf64>
    %inserted = tensor.insert %extracted into %iter[%ra, %rb] :
        tensor<98304x4xf64>
    xla_gpu.yield %inserted : tensor<98304x4xf64>
  }
  return %xla_loop_1 : tensor<98304x4xf64>
}

// CHECK-LABEL: @do_not_fuse_dependent_loops
// CHECK-COUNT-2: xla_gpu.loop

0 comments on commit 80784a0

Please sign in to comment.