diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgRandom.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgRandom.cpp index e7fc85fb6c27e..2fa4be192fe17 100644 --- a/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgRandom.cpp +++ b/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgRandom.cpp @@ -196,6 +196,16 @@ std::pair extractKey32(OpBuilder &builder, Location loc, return std::pair(pair.first, pair.second); } + // TODO(suderman): This gets the 128-bit counter to work however + // may not match XLA. + if (storeTy.getDimSize(0) == 3 && storeETy.isInteger(64)) { + Value idx1 = builder.create(loc, 0); + Value state = builder.create(loc, store, idx1); + Value cast = builder.create(loc, i64Ty, state); + auto pair = splitI64(ArithOpBuilder(builder, loc, cast)); + return std::pair(pair.first, pair.second); + } + return {nullptr, nullptr}; } @@ -215,6 +225,15 @@ Value extractState64(OpBuilder &builder, Location loc, Value store) { return cast; } + // TODO(suderman): This gets the 128-bit counter to work however + // may not match XLA. + if (storeTy.getDimSize(0) == 3 && storeETy.isInteger(64)) { + Value idx1 = builder.create(loc, 1); + Value state = builder.create(loc, store, idx1); + Value cast = builder.create(loc, i64Ty, state); + return cast; + } + if (storeTy.getDimSize(0) == 4 && storeETy.isInteger(32)) { Value idx2 = builder.create(loc, 2); Value idx3 = builder.create(loc, 3); @@ -244,6 +263,15 @@ Value setState64(OpBuilder &b, Location loc, Value store, Value state) { ValueRange{idx1}); } + // TODO(suderman): This gets the 128-bit counter to work however + // may not match XLA. + if (storeTy.getDimSize(0) == 3 && storeETy.isInteger(64)) { + state = b.create(loc, storeETy, state); + Value idx1 = b.create(loc, 1); + return b.create(loc, storeTy, state, store, + ValueRange{idx1}); + } + if (storeTy.getDimSize(0) == 4 && storeETy.isInteger(32)) { Value idx2 = b.create(loc, 2); Value idx3 = b.create(loc, 3); diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/test/stablehlo_to_linalg_random.mlir b/compiler/src/iree/compiler/InputConversion/StableHLO/test/stablehlo_to_linalg_random.mlir index fae5d5dc98487..97fb254544d81 100644 --- a/compiler/src/iree/compiler/InputConversion/StableHLO/test/stablehlo_to_linalg_random.mlir +++ b/compiler/src/iree/compiler/InputConversion/StableHLO/test/stablehlo_to_linalg_random.mlir @@ -525,6 +525,15 @@ func.func @philox_i32(%arg0: tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi32>) { // CHECK: return %[[INSERTED]], %[[COLLAPSE]] +// ----- + +// CHECK-LABEL: func.func @philox_128_i32 +// CHECK-SAME: %[[ARG0:.*]]: tensor<3xi64> +func.func @philox_128_i32(%arg0: tensor<3xi64>) -> (tensor<3xi64>, tensor<8xi32>) { + %output_state, %output = "stablehlo.rng_bit_generator"(%arg0) {rng_algorithm = #stablehlo} : (tensor<3xi64>) -> (tensor<3xi64>, tensor<8xi32>) + return %output_state, %output : tensor<3xi64>, tensor<8xi32> +} + // ----- func.func @philox_i32_odd(%arg0: tensor<2xi64>) -> (tensor<2xi64>, tensor<7x11xi32>) {