Skip to content

Commit

Permalink
Land a short-term fix for 128-bit stablehlo RNG
Browse files Browse the repository at this point in the history
We can use a 128-bit counter for Philox random number generation. In these
cases only 64 bits of the counter are loaded. This should work in the short
term, and only impacts cases where more than 2^64 random numbers are
generated.
  • Loading branch information
rsuderman committed Aug 25, 2023
1 parent b0c77fa commit 49b9d65
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,16 @@ std::pair<Value, Value> extractKey32(OpBuilder &builder, Location loc,
return std::pair<Value, Value>(pair.first, pair.second);
}

// TODO(suderman): This gets the 128-bit counter to work; however, it
// may not match XLA.
if (storeTy.getDimSize(0) == 3 && storeETy.isInteger(64)) {
Value idx1 = builder.create<arith::ConstantIndexOp>(loc, 0);
Value state = builder.create<tensor::ExtractOp>(loc, store, idx1);
Value cast = builder.create<arith::BitcastOp>(loc, i64Ty, state);
auto pair = splitI64(ArithOpBuilder(builder, loc, cast));
return std::pair<Value, Value>(pair.first, pair.second);
}

return {nullptr, nullptr};
}

Expand All @@ -215,6 +225,15 @@ Value extractState64(OpBuilder &builder, Location loc, Value store) {
return cast;
}

// TODO(suderman): This gets the 128-bit counter to work; however, it
// may not match XLA.
if (storeTy.getDimSize(0) == 3 && storeETy.isInteger(64)) {
Value idx1 = builder.create<arith::ConstantIndexOp>(loc, 1);
Value state = builder.create<tensor::ExtractOp>(loc, store, idx1);
Value cast = builder.create<arith::BitcastOp>(loc, i64Ty, state);
return cast;
}

if (storeTy.getDimSize(0) == 4 && storeETy.isInteger(32)) {
Value idx2 = builder.create<arith::ConstantIndexOp>(loc, 2);
Value idx3 = builder.create<arith::ConstantIndexOp>(loc, 3);
Expand Down Expand Up @@ -244,6 +263,15 @@ Value setState64(OpBuilder &b, Location loc, Value store, Value state) {
ValueRange{idx1});
}

// TODO(suderman): This gets the 128-bit counter to work; however, it
// may not match XLA.
if (storeTy.getDimSize(0) == 3 && storeETy.isInteger(64)) {
state = b.create<arith::BitcastOp>(loc, storeETy, state);
Value idx1 = b.create<arith::ConstantIndexOp>(loc, 1);
return b.create<tensor::InsertOp>(loc, storeTy, state, store,
ValueRange{idx1});
}

if (storeTy.getDimSize(0) == 4 && storeETy.isInteger(32)) {
Value idx2 = b.create<arith::ConstantIndexOp>(loc, 2);
Value idx3 = b.create<arith::ConstantIndexOp>(loc, 3);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,15 @@ func.func @philox_i32(%arg0: tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi32>) {
// CHECK: return %[[INSERTED]], %[[COLLAPSE]]


// -----

// Exercises PHILOX bit generation with a three-element i64 state tensor.
// Only the function label and its argument type are matched below.
// CHECK-LABEL: func.func @philox_128_i32
// CHECK-SAME: %[[ARG0:.*]]: tensor<3xi64>
func.func @philox_128_i32(%state_in: tensor<3xi64>) -> (tensor<3xi64>, tensor<8xi32>) {
  %next_state, %bits = "stablehlo.rng_bit_generator"(%state_in) {rng_algorithm = #stablehlo<rng_algorithm PHILOX>} : (tensor<3xi64>) -> (tensor<3xi64>, tensor<8xi32>)
  return %next_state, %bits : tensor<3xi64>, tensor<8xi32>
}

// -----

func.func @philox_i32_odd(%arg0: tensor<2xi64>) -> (tensor<2xi64>, tensor<7x11xi32>) {
Expand Down

0 comments on commit 49b9d65

Please sign in to comment.