diff --git a/frontends/torch-frontend/examples/inference/mixtral/README.md b/frontends/torch-frontend/examples/inference/mixtral/README.md
new file mode 100644
index 000000000..e69de29bb
diff --git a/frontends/torch-frontend/examples/inference/mixtral/infer_single_mixtral.py b/frontends/torch-frontend/examples/inference/mixtral/infer_single_mixtral.py
index 380d6d473..2d654de07 100644
--- a/frontends/torch-frontend/examples/inference/mixtral/infer_single_mixtral.py
+++ b/frontends/torch-frontend/examples/inference/mixtral/infer_single_mixtral.py
@@ -3,6 +3,7 @@
 import torch
 from torch._subclasses.fake_tensor import FakeTensorMode
+from torch.fx.experimental.symbolic_shapes import ShapeEnv
 from functorch.compile import aot_module, default_partition
 from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode
 from torch._decomp import (
@@ -18,53 +19,6 @@
 )
 from transformers.models.mixtral.configuration_mixtral import MixtralConfig
-
-def fake_export_whole_mixtral():
-    model_conf = MixtralConfig()
-    with FakeTensorMode():
-        mixtral = MixtralModel(model_conf)
-        # step 1: fake init
-        print(mixtral)
-
-        """
-        MixtralModel(
-          (embed_tokens): Embedding(32000, 4096)
-          (layers): ModuleList(
-            (0-31): 32 x MixtralDecoderLayer(
-              (self_attn): MixtralSdpaAttention(
-                (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
-                (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
-                (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
-                (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
-                (rotary_emb): MixtralRotaryEmbedding()
-              )
-              (block_sparse_moe): MixtralSparseMoeBlock(
-                (gate): Linear(in_features=4096, out_features=8, bias=False)
-                (experts): ModuleList(
-                  (0-7): 8 x MixtralBlockSparseTop2MLP(
-                    (w1): Linear(in_features=4096, out_features=14336, bias=False)
-                    (w2): Linear(in_features=14336, out_features=4096, bias=False)
-                    (w3): Linear(in_features=4096, out_features=14336, bias=False)
-                    (act_fn): SiLU()
-                  )
-                )
-              )
-              (input_layernorm): MixtralRMSNorm()
-              (post_attention_layernorm): MixtralRMSNorm()
-            )
-          )
-          (norm): MixtralRMSNorm()
-        )
-        """
-
-        # step 2: torch.export
-        bsz = 5
-        seq_len = 7
-        vocab_size = 32000
-        token_ids = torch.randint(0, vocab_size, (bsz, seq_len))
-        exported_mixtral = torch.export.export(mixtral, (token_ids,))
-        print(exported_mixtral)
-
 def export_mixtral_decoding_layer():
     torch._dynamo.config.capture_dynamic_output_shape_ops = True
     model_conf = MixtralConfig(hidden_size=32)
@@ -79,5 +33,48 @@ def export_mixtral_decoding_layer():
     module = torch_frontend.compile_dynamo_model(mixtral_decoder_layer, output_type="stablehlo")
     print(module.operation.get_asm())
 
+
+def export_fake_whole_mixtral_model():
+    torch._dynamo.config.capture_dynamic_output_shape_ops = True
+    mixtral_config = MixtralConfig(attn_implementation="eager")
+    bsz = 3
+    seqlen = 1
+    past_kv_seqlen = 7
+    with FakeTensorMode(shape_env=ShapeEnv()):
+        mixtral_model = MixtralModel(mixtral_config)
+        input_ids = torch.randint(0, mixtral_config.vocab_size, (bsz, seqlen))
+
+        past_key_values = []
+        for i in range(mixtral_config.num_hidden_layers):
+            past_key_values.append(
+                (
+                    torch.rand(
+                        bsz,
+                        mixtral_config.num_key_value_heads,
+                        past_kv_seqlen,
+                        mixtral_config.hidden_size
+                        // mixtral_config.num_attention_heads,
+                    ),
+                    torch.rand(
+                        bsz,
+                        mixtral_config.num_key_value_heads,
+                        past_kv_seqlen,
+                        mixtral_config.hidden_size
+                        // mixtral_config.num_attention_heads,
+                    ),
+                )
+            )
+        exported_mixtral = torch.export.export(
+            mixtral_model,
+            (input_ids, None, None, past_key_values),
+        )
+        import torch_frontend
+        module = torch_frontend.compile_dynamo_model(
+            exported_mixtral,
+            output_type="stablehlo",
+        )
+        print(module.operation.get_asm())
+
 if __name__ == "__main__":
+    export_fake_whole_mixtral_model()
     export_mixtral_decoding_layer()
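Note on the new export path: torch.export.export(mixtral_model, (input_ids, None, None, past_key_values)) relies on the positional parameter order of MixtralModel.forward in the transformers release this example targets, i.e. (input_ids, attention_mask, position_ids, past_key_values, ...), with past_key_values supplied in the legacy per-layer (key, value) tuple format, each tensor shaped (bsz, num_key_value_heads, past_kv_seqlen, head_dim). Below is a minimal sketch, not taken from the repository, of a guard one could run before exporting to fail fast if that positional order ever changes:

    import inspect
    from transformers.models.mixtral.modeling_mixtral import MixtralModel

    # The export passes (input_ids, None, None, past_key_values) positionally, so
    # forward()'s first four user-facing parameters must still be exactly these.
    expected = ["input_ids", "attention_mask", "position_ids", "past_key_values"]
    params = [p for p in inspect.signature(MixtralModel.forward).parameters if p != "self"]
    assert params[:4] == expected, f"MixtralModel.forward signature changed: {params[:4]}"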
diff --git a/frontends/torch-frontend/examples/inference/mixtral/mixtral.stablehlo.mlir b/frontends/torch-frontend/examples/inference/mixtral/mixtral.stablehlo.mlir
new file mode 100644
index 000000000..5453caf49
--- /dev/null
+++ b/frontends/torch-frontend/examples/inference/mixtral/mixtral.stablehlo.mlir
@@ -0,0 +1,32271 @@
+module {
+  // mixtral.stablehlo.mlir: the 32,271-line StableHLO module printed by export_fake_whole_mixtral_model()
+  // for the fake-initialized 32-layer MixtralModel (bsz = 3, seqlen = 1, past_kv_seqlen = 7).
+  // @main takes 129 arguments: %arg0..%arg63 : tensor<131072x128xf32> (the per-layer cos/sin rotary tables),
+  // %arg64 : tensor<3x1xi64> (the input ids), and %arg65..%arg128 : tensor<3x8x7x128xf32> (the past
+  // key/value cache). It returns the tensor<3x1x4096xf32> hidden states together with 64 updated
+  // tensor<3x8x8x128xf32> cache entries.
+  // The body is stablehlo/shape/arith IR with ByteIR custom calls (@byteir.softmax, @byteir.top_k,
+  // @byteir.non_zero) implementing the attention softmax and the dynamic top-2 expert routing.
+  func.func @main(%arg0: tensor<131072x128xf32>, ..., %arg64: tensor<3x1xi64>, %arg65: tensor<3x8x7x128xf32>, ..., %arg128: tensor<3x8x7x128xf32>) -> (tensor<3x1x4096xf32>, tensor<3x8x8x128xf32>, ...) {
+    ...
+  }
+}
%from_elements_157, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_158 = tensor.from_elements %475, %c4096_i64, %c1_i64 : tensor<3xi64> + %477 = stablehlo.dynamic_reshape %474, %from_elements_158 : (tensor, tensor<3xi64>) -> tensor + %478 = stablehlo.dynamic_iota %from_elements_158, dim = 1 : (tensor<3xi64>) -> tensor + %479 = stablehlo.concatenate %477, %478, dim = 2 : (tensor, tensor) -> tensor + %480 = "stablehlo.scatter"(%417, %479, %476) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %481 = stablehlo.slice %140 [5:6, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %482 = stablehlo.reshape %481 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %483 = stablehlo.custom_call @byteir.non_zero(%482) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_159 = tensor.dim %483, %c0 : tensor + %484 = arith.index_cast %dim_159 : index to i64 + %from_elements_160 = tensor.from_elements %484, %c1_i64 : tensor<2xi64> + %485 = stablehlo.real_dynamic_slice %483, %c_22, %from_elements_160, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_161 = tensor.dim %485, %c0 : tensor + %486 = arith.index_cast %dim_161 : index to i64 + %from_elements_162 = tensor.from_elements %486 : tensor<1xi64> + %487 = stablehlo.dynamic_reshape %485, %from_elements_162 : (tensor, tensor<1xi64>) -> tensor + %from_elements_163 = tensor.from_elements %484, %c2_i64 : tensor<2xi64> + %488 = stablehlo.real_dynamic_slice %483, %c_24, %from_elements_163, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_164 = tensor.dim %488, %c0 : tensor + %489 = arith.index_cast %dim_164 : index to i64 + %from_elements_165 = tensor.from_elements %489 : tensor<1xi64> + %490 = stablehlo.dynamic_reshape %488, %from_elements_165 : (tensor, tensor<1xi64>) -> tensor + %dim_166 = tensor.dim %490, %c0 : tensor + %491 = arith.index_cast %dim_166 : index to i64 + %from_elements_167 = tensor.from_elements %491, %c1_i64 : tensor<2xi64> + %492 = stablehlo.dynamic_reshape %490, %from_elements_167 : (tensor, tensor<2xi64>) -> tensor + %dim_168 = tensor.dim %492, %c0 : tensor + %493 = arith.index_cast %dim_168 : index to i64 + %from_elements_169 = tensor.from_elements %c1_i64, %493, %c4096_i64 : tensor<3xi64> + %494 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_169, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_170 = tensor.dim %494, %c1 : tensor<1x?x4096xi64> + %495 = arith.index_cast %dim_170 : index to i64 + %from_elements_171 = tensor.from_elements %c1_i64, %495, %c4096_i64, %c1_i64 : tensor<4xi64> + %496 = stablehlo.dynamic_reshape %494, %from_elements_171 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %497 = stablehlo.dynamic_broadcast_in_dim %492, %from_elements_169, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_172 = tensor.dim %497, %c1 : tensor<1x?x4096xi64> + %498 = arith.index_cast %dim_172 : index to i64 + %from_elements_173 = tensor.from_elements %c1_i64, %498, %c4096_i64, %c1_i64 : tensor<4xi64> + %499 = stablehlo.dynamic_reshape %497, %from_elements_173 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %500 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_169, dims = [2] : 
(tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_174 = tensor.dim %500, %c1 : tensor<1x?x4096xi64> + %501 = arith.index_cast %dim_174 : index to i64 + %from_elements_175 = tensor.from_elements %c1_i64, %501, %c4096_i64, %c1_i64 : tensor<4xi64> + %502 = stablehlo.dynamic_reshape %500, %from_elements_175 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %503 = stablehlo.concatenate %496, %499, %502, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %504 = "stablehlo.gather"(%151, %503) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %505 = shape.shape_of %504 : tensor<1x?x4096xf32> -> tensor<3xindex> + %506 = shape.num_elements %505 : tensor<3xindex> -> index + %507 = stablehlo.compute_reshape_shape %506, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %508 = stablehlo.dynamic_reshape %504, %507 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %509 = stablehlo.dot %508, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %510 = stablehlo.logistic %509 : tensor + %511 = shape.shape_of %510 : tensor -> tensor<2xindex> + %512 = shape.shape_of %509 : tensor -> tensor<2xindex> + %513 = shape.cstr_broadcastable %511, %512 : tensor<2xindex>, tensor<2xindex> + %514 = shape.assuming %513 -> (tensor) { + %19688 = shape.broadcast %511, %512 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %510, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %509, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %515 = shape.shape_of %514 : tensor -> tensor<2xindex> + %516 = shape.cstr_broadcastable %515, %512 : tensor<2xindex>, tensor<2xindex> + %517 = shape.assuming %516 -> (tensor) { + %19688 = shape.broadcast %515, %512 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %514, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %509, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %518 = stablehlo.dot %517, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_176 = tensor.dim %490, %c0 : tensor + %519 = arith.index_cast %dim_176 : index to i64 + %from_elements_177 = tensor.from_elements %519, %c1_i64 : tensor<2xi64> + %520 = stablehlo.dynamic_reshape %490, %from_elements_177 : (tensor, tensor<2xi64>) -> tensor + %dim_178 = tensor.dim %487, %c0 : tensor + %521 = arith.index_cast %dim_178 : index to i64 + %from_elements_179 = tensor.from_elements %521, %c1_i64 : tensor<2xi64> + %522 = stablehlo.dynamic_reshape %487, %from_elements_179 : (tensor, tensor<2xi64>) -> tensor + %523 = stablehlo.concatenate %520, %522, dim = 1 : (tensor, tensor) -> tensor + %524 = "stablehlo.gather"(%202, %523) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %525 = shape.shape_of %518 : tensor -> tensor<2xindex> + %526 = shape.shape_of %524 : tensor -> tensor<2xindex> + %527 = shape.cstr_broadcastable %525, %526 : tensor<2xindex>, tensor<2xindex> + %528 = shape.assuming %527 -> (tensor) { + %19688 = shape.broadcast 
%525, %526 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %518, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %524, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %529 = shape.shape_of %528 : tensor -> tensor<2xindex> + %530 = stablehlo.dynamic_broadcast_in_dim %528, %529, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %531 = stablehlo.dynamic_broadcast_in_dim %213, %529, dims = [] : (tensor, tensor<2xindex>) -> tensor + %532 = stablehlo.multiply %530, %531 : tensor + %dim_180 = tensor.dim %492, %c0 : tensor + %533 = arith.index_cast %dim_180 : index to i64 + %dim_181 = tensor.dim %528, %c0 : tensor + %534 = arith.index_cast %dim_181 : index to i64 + %535 = arith.maxsi %533, %534 : i64 + %536 = arith.index_cast %535 : i64 to index + %from_elements_182 = tensor.from_elements %536, %c4096 : tensor<2xindex> + %537 = stablehlo.dynamic_broadcast_in_dim %492, %from_elements_182, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_183 = tensor.dim %537, %c0 : tensor + %538 = arith.index_cast %dim_183 : index to i64 + %from_elements_184 = tensor.from_elements %538, %c4096_i64 : tensor<2xi64> + %539 = stablehlo.real_dynamic_slice %532, %c_22, %from_elements_184, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_185 = tensor.from_elements %538, %c4096_i64, %c1_i64 : tensor<3xi64> + %540 = stablehlo.dynamic_reshape %537, %from_elements_185 : (tensor, tensor<3xi64>) -> tensor + %541 = stablehlo.dynamic_iota %from_elements_185, dim = 1 : (tensor<3xi64>) -> tensor + %542 = stablehlo.concatenate %540, %541, dim = 2 : (tensor, tensor) -> tensor + %543 = "stablehlo.scatter"(%480, %542, %539) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %544 = stablehlo.slice %140 [6:7, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %545 = stablehlo.reshape %544 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %546 = stablehlo.custom_call @byteir.non_zero(%545) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_186 = tensor.dim %546, %c0 : tensor + %547 = arith.index_cast %dim_186 : index to i64 + %from_elements_187 = tensor.from_elements %547, %c1_i64 : tensor<2xi64> + %548 = stablehlo.real_dynamic_slice %546, %c_22, %from_elements_187, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_188 = tensor.dim %548, %c0 : tensor + %549 = arith.index_cast %dim_188 : index to i64 + %from_elements_189 = tensor.from_elements %549 : tensor<1xi64> + %550 = stablehlo.dynamic_reshape %548, %from_elements_189 : (tensor, tensor<1xi64>) -> tensor + %from_elements_190 = tensor.from_elements %547, %c2_i64 : tensor<2xi64> + %551 = stablehlo.real_dynamic_slice %546, %c_24, %from_elements_190, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_191 = tensor.dim %551, %c0 : tensor + %552 = arith.index_cast %dim_191 : index to i64 + %from_elements_192 = tensor.from_elements %552 : tensor<1xi64> + %553 = stablehlo.dynamic_reshape %551, %from_elements_192 : (tensor, tensor<1xi64>) -> tensor + %dim_193 = tensor.dim %553, %c0 : tensor + %554 = arith.index_cast 
%dim_193 : index to i64 + %from_elements_194 = tensor.from_elements %554, %c1_i64 : tensor<2xi64> + %555 = stablehlo.dynamic_reshape %553, %from_elements_194 : (tensor, tensor<2xi64>) -> tensor + %dim_195 = tensor.dim %555, %c0 : tensor + %556 = arith.index_cast %dim_195 : index to i64 + %from_elements_196 = tensor.from_elements %c1_i64, %556, %c4096_i64 : tensor<3xi64> + %557 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_196, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_197 = tensor.dim %557, %c1 : tensor<1x?x4096xi64> + %558 = arith.index_cast %dim_197 : index to i64 + %from_elements_198 = tensor.from_elements %c1_i64, %558, %c4096_i64, %c1_i64 : tensor<4xi64> + %559 = stablehlo.dynamic_reshape %557, %from_elements_198 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %560 = stablehlo.dynamic_broadcast_in_dim %555, %from_elements_196, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_199 = tensor.dim %560, %c1 : tensor<1x?x4096xi64> + %561 = arith.index_cast %dim_199 : index to i64 + %from_elements_200 = tensor.from_elements %c1_i64, %561, %c4096_i64, %c1_i64 : tensor<4xi64> + %562 = stablehlo.dynamic_reshape %560, %from_elements_200 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %563 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_196, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_201 = tensor.dim %563, %c1 : tensor<1x?x4096xi64> + %564 = arith.index_cast %dim_201 : index to i64 + %from_elements_202 = tensor.from_elements %c1_i64, %564, %c4096_i64, %c1_i64 : tensor<4xi64> + %565 = stablehlo.dynamic_reshape %563, %from_elements_202 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %566 = stablehlo.concatenate %559, %562, %565, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %567 = "stablehlo.gather"(%151, %566) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %568 = shape.shape_of %567 : tensor<1x?x4096xf32> -> tensor<3xindex> + %569 = shape.num_elements %568 : tensor<3xindex> -> index + %570 = stablehlo.compute_reshape_shape %569, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %571 = stablehlo.dynamic_reshape %567, %570 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %572 = stablehlo.dot %571, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %573 = stablehlo.logistic %572 : tensor + %574 = shape.shape_of %573 : tensor -> tensor<2xindex> + %575 = shape.shape_of %572 : tensor -> tensor<2xindex> + %576 = shape.cstr_broadcastable %574, %575 : tensor<2xindex>, tensor<2xindex> + %577 = shape.assuming %576 -> (tensor) { + %19688 = shape.broadcast %574, %575 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %573, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %572, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %578 = shape.shape_of %577 : tensor -> tensor<2xindex> + %579 = shape.cstr_broadcastable %578, %575 : tensor<2xindex>, tensor<2xindex> + %580 = shape.assuming %579 -> (tensor) { + %19688 = shape.broadcast %578, %575 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim 
%577, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %572, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %581 = stablehlo.dot %580, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_203 = tensor.dim %553, %c0 : tensor + %582 = arith.index_cast %dim_203 : index to i64 + %from_elements_204 = tensor.from_elements %582, %c1_i64 : tensor<2xi64> + %583 = stablehlo.dynamic_reshape %553, %from_elements_204 : (tensor, tensor<2xi64>) -> tensor + %dim_205 = tensor.dim %550, %c0 : tensor + %584 = arith.index_cast %dim_205 : index to i64 + %from_elements_206 = tensor.from_elements %584, %c1_i64 : tensor<2xi64> + %585 = stablehlo.dynamic_reshape %550, %from_elements_206 : (tensor, tensor<2xi64>) -> tensor + %586 = stablehlo.concatenate %583, %585, dim = 1 : (tensor, tensor) -> tensor + %587 = "stablehlo.gather"(%202, %586) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %588 = shape.shape_of %581 : tensor -> tensor<2xindex> + %589 = shape.shape_of %587 : tensor -> tensor<2xindex> + %590 = shape.cstr_broadcastable %588, %589 : tensor<2xindex>, tensor<2xindex> + %591 = shape.assuming %590 -> (tensor) { + %19688 = shape.broadcast %588, %589 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %581, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %587, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %592 = shape.shape_of %591 : tensor -> tensor<2xindex> + %593 = stablehlo.dynamic_broadcast_in_dim %591, %592, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %594 = stablehlo.dynamic_broadcast_in_dim %213, %592, dims = [] : (tensor, tensor<2xindex>) -> tensor + %595 = stablehlo.multiply %593, %594 : tensor + %dim_207 = tensor.dim %555, %c0 : tensor + %596 = arith.index_cast %dim_207 : index to i64 + %dim_208 = tensor.dim %591, %c0 : tensor + %597 = arith.index_cast %dim_208 : index to i64 + %598 = arith.maxsi %596, %597 : i64 + %599 = arith.index_cast %598 : i64 to index + %from_elements_209 = tensor.from_elements %599, %c4096 : tensor<2xindex> + %600 = stablehlo.dynamic_broadcast_in_dim %555, %from_elements_209, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_210 = tensor.dim %600, %c0 : tensor + %601 = arith.index_cast %dim_210 : index to i64 + %from_elements_211 = tensor.from_elements %601, %c4096_i64 : tensor<2xi64> + %602 = stablehlo.real_dynamic_slice %595, %c_22, %from_elements_211, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_212 = tensor.from_elements %601, %c4096_i64, %c1_i64 : tensor<3xi64> + %603 = stablehlo.dynamic_reshape %600, %from_elements_212 : (tensor, tensor<3xi64>) -> tensor + %604 = stablehlo.dynamic_iota %from_elements_212, dim = 1 : (tensor<3xi64>) -> tensor + %605 = stablehlo.concatenate %603, %604, dim = 2 : (tensor, tensor) -> tensor + %606 = "stablehlo.scatter"(%543, %605, %602) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) 
-> tensor<3x4096xf32> + %607 = stablehlo.slice %140 [7:8, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %608 = stablehlo.reshape %607 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %609 = stablehlo.custom_call @byteir.non_zero(%608) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_213 = tensor.dim %609, %c0 : tensor + %610 = arith.index_cast %dim_213 : index to i64 + %from_elements_214 = tensor.from_elements %610, %c1_i64 : tensor<2xi64> + %611 = stablehlo.real_dynamic_slice %609, %c_22, %from_elements_214, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_215 = tensor.dim %611, %c0 : tensor + %612 = arith.index_cast %dim_215 : index to i64 + %from_elements_216 = tensor.from_elements %612 : tensor<1xi64> + %613 = stablehlo.dynamic_reshape %611, %from_elements_216 : (tensor, tensor<1xi64>) -> tensor + %from_elements_217 = tensor.from_elements %610, %c2_i64 : tensor<2xi64> + %614 = stablehlo.real_dynamic_slice %609, %c_24, %from_elements_217, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_218 = tensor.dim %614, %c0 : tensor + %615 = arith.index_cast %dim_218 : index to i64 + %from_elements_219 = tensor.from_elements %615 : tensor<1xi64> + %616 = stablehlo.dynamic_reshape %614, %from_elements_219 : (tensor, tensor<1xi64>) -> tensor + %dim_220 = tensor.dim %616, %c0 : tensor + %617 = arith.index_cast %dim_220 : index to i64 + %from_elements_221 = tensor.from_elements %617, %c1_i64 : tensor<2xi64> + %618 = stablehlo.dynamic_reshape %616, %from_elements_221 : (tensor, tensor<2xi64>) -> tensor + %dim_222 = tensor.dim %618, %c0 : tensor + %619 = arith.index_cast %dim_222 : index to i64 + %from_elements_223 = tensor.from_elements %c1_i64, %619, %c4096_i64 : tensor<3xi64> + %620 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_223, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_224 = tensor.dim %620, %c1 : tensor<1x?x4096xi64> + %621 = arith.index_cast %dim_224 : index to i64 + %from_elements_225 = tensor.from_elements %c1_i64, %621, %c4096_i64, %c1_i64 : tensor<4xi64> + %622 = stablehlo.dynamic_reshape %620, %from_elements_225 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %623 = stablehlo.dynamic_broadcast_in_dim %618, %from_elements_223, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_226 = tensor.dim %623, %c1 : tensor<1x?x4096xi64> + %624 = arith.index_cast %dim_226 : index to i64 + %from_elements_227 = tensor.from_elements %c1_i64, %624, %c4096_i64, %c1_i64 : tensor<4xi64> + %625 = stablehlo.dynamic_reshape %623, %from_elements_227 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %626 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_223, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_228 = tensor.dim %626, %c1 : tensor<1x?x4096xi64> + %627 = arith.index_cast %dim_228 : index to i64 + %from_elements_229 = tensor.from_elements %c1_i64, %627, %c4096_i64, %c1_i64 : tensor<4xi64> + %628 = stablehlo.dynamic_reshape %626, %from_elements_229 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %629 = stablehlo.concatenate %622, %625, %628, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %630 = "stablehlo.gather"(%151, %629) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %631 = 
shape.shape_of %630 : tensor<1x?x4096xf32> -> tensor<3xindex> + %632 = shape.num_elements %631 : tensor<3xindex> -> index + %633 = stablehlo.compute_reshape_shape %632, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %634 = stablehlo.dynamic_reshape %630, %633 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %635 = stablehlo.dot %634, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %636 = stablehlo.logistic %635 : tensor + %637 = shape.shape_of %636 : tensor -> tensor<2xindex> + %638 = shape.shape_of %635 : tensor -> tensor<2xindex> + %639 = shape.cstr_broadcastable %637, %638 : tensor<2xindex>, tensor<2xindex> + %640 = shape.assuming %639 -> (tensor) { + %19688 = shape.broadcast %637, %638 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %636, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %635, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %641 = shape.shape_of %640 : tensor -> tensor<2xindex> + %642 = shape.cstr_broadcastable %641, %638 : tensor<2xindex>, tensor<2xindex> + %643 = shape.assuming %642 -> (tensor) { + %19688 = shape.broadcast %641, %638 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %640, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %635, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %644 = stablehlo.dot %643, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_230 = tensor.dim %616, %c0 : tensor + %645 = arith.index_cast %dim_230 : index to i64 + %from_elements_231 = tensor.from_elements %645, %c1_i64 : tensor<2xi64> + %646 = stablehlo.dynamic_reshape %616, %from_elements_231 : (tensor, tensor<2xi64>) -> tensor + %dim_232 = tensor.dim %613, %c0 : tensor + %647 = arith.index_cast %dim_232 : index to i64 + %from_elements_233 = tensor.from_elements %647, %c1_i64 : tensor<2xi64> + %648 = stablehlo.dynamic_reshape %613, %from_elements_233 : (tensor, tensor<2xi64>) -> tensor + %649 = stablehlo.concatenate %646, %648, dim = 1 : (tensor, tensor) -> tensor + %650 = "stablehlo.gather"(%202, %649) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %651 = shape.shape_of %644 : tensor -> tensor<2xindex> + %652 = shape.shape_of %650 : tensor -> tensor<2xindex> + %653 = shape.cstr_broadcastable %651, %652 : tensor<2xindex>, tensor<2xindex> + %654 = shape.assuming %653 -> (tensor) { + %19688 = shape.broadcast %651, %652 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %644, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %650, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %655 = shape.shape_of %654 : tensor -> tensor<2xindex> + %656 = stablehlo.dynamic_broadcast_in_dim %654, %655, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %657 = stablehlo.dynamic_broadcast_in_dim %213, %655, dims = [] : (tensor, tensor<2xindex>) -> tensor + %658 = stablehlo.multiply %656, %657 : tensor + %dim_234 = tensor.dim %618, %c0 : tensor + %659 = 
arith.index_cast %dim_234 : index to i64 + %dim_235 = tensor.dim %654, %c0 : tensor + %660 = arith.index_cast %dim_235 : index to i64 + %661 = arith.maxsi %659, %660 : i64 + %662 = arith.index_cast %661 : i64 to index + %from_elements_236 = tensor.from_elements %662, %c4096 : tensor<2xindex> + %663 = stablehlo.dynamic_broadcast_in_dim %618, %from_elements_236, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_237 = tensor.dim %663, %c0 : tensor + %664 = arith.index_cast %dim_237 : index to i64 + %from_elements_238 = tensor.from_elements %664, %c4096_i64 : tensor<2xi64> + %665 = stablehlo.real_dynamic_slice %658, %c_22, %from_elements_238, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_239 = tensor.from_elements %664, %c4096_i64, %c1_i64 : tensor<3xi64> + %666 = stablehlo.dynamic_reshape %663, %from_elements_239 : (tensor, tensor<3xi64>) -> tensor + %667 = stablehlo.dynamic_iota %from_elements_239, dim = 1 : (tensor<3xi64>) -> tensor + %668 = stablehlo.concatenate %666, %667, dim = 2 : (tensor, tensor) -> tensor + %669 = "stablehlo.scatter"(%606, %668, %665) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %670 = stablehlo.reshape %669 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %671 = stablehlo.add %102, %670 : tensor<3x1x4096xf32> + %672 = stablehlo.broadcast_in_dim %671, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %673 = stablehlo.power %672, %15 : tensor<3x1x4096xf32> + %674 = stablehlo.reduce(%673 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %675 = stablehlo.reshape %674 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %676 = stablehlo.broadcast_in_dim %675, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %677 = stablehlo.divide %676, %21 : tensor<3x1x1xf32> + %678 = stablehlo.broadcast_in_dim %677, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %679 = stablehlo.add %678, %25 : tensor<3x1x1xf32> + %680 = stablehlo.rsqrt %679 : tensor<3x1x1xf32> + %681 = stablehlo.broadcast_in_dim %680, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %682 = stablehlo.multiply %672, %681 : tensor<3x1x4096xf32> + %683 = stablehlo.broadcast_in_dim %682, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %684 = stablehlo.multiply %683, %31 : tensor<3x1x4096xf32> + %685 = stablehlo.reshape %684 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %686 = stablehlo.dot %685, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %687 = stablehlo.reshape %686 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %688 = stablehlo.dot %685, %37 : (tensor<3x4096xf32>, tensor<4096x1024xf32>) -> tensor<3x1024xf32> + %689 = stablehlo.reshape %688 : (tensor<3x1024xf32>) -> tensor<3x1x1024xf32> + %690 = stablehlo.reshape %687 : (tensor<3x1x4096xf32>) -> tensor<3x1x32x128xf32> + %691 = stablehlo.transpose %690, dims = [0, 2, 1, 3] : (tensor<3x1x32x128xf32>) -> tensor<3x32x1x128xf32> + %692 = stablehlo.reshape %689 : (tensor<3x1x1024xf32>) -> tensor<3x1x8x128xf32> + %693 = stablehlo.transpose %692, dims = [0, 2, 1, 3] : (tensor<3x1x8x128xf32>) -> tensor<3x8x1x128xf32> + %694 = stablehlo.slice %arg2 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + 
%695 = stablehlo.slice %arg3 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %696 = "stablehlo.gather"(%694, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %697 = stablehlo.reshape %696 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %698 = "stablehlo.gather"(%695, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %699 = stablehlo.reshape %698 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %700 = stablehlo.broadcast_in_dim %691, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %701 = stablehlo.broadcast_in_dim %697, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %702 = stablehlo.multiply %700, %701 : tensor<3x32x1x128xf32> + %703 = stablehlo.slice %691 [0:3, 0:32, 0:1, 0:64] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %704 = stablehlo.slice %691 [0:3, 0:32, 0:1, 64:128] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %705 = stablehlo.negate %704 : tensor<3x32x1x64xf32> + %706 = stablehlo.concatenate %705, %703, dim = 3 : (tensor<3x32x1x64xf32>, tensor<3x32x1x64xf32>) -> tensor<3x32x1x128xf32> + %707 = stablehlo.broadcast_in_dim %706, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %708 = stablehlo.broadcast_in_dim %699, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %709 = stablehlo.multiply %707, %708 : tensor<3x32x1x128xf32> + %710 = stablehlo.add %702, %709 : tensor<3x32x1x128xf32> + %711 = stablehlo.broadcast_in_dim %693, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %712 = stablehlo.broadcast_in_dim %697, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %713 = stablehlo.multiply %711, %712 : tensor<3x8x1x128xf32> + %714 = stablehlo.slice %693 [0:3, 0:8, 0:1, 0:64] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %715 = stablehlo.slice %693 [0:3, 0:8, 0:1, 64:128] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %716 = stablehlo.negate %715 : tensor<3x8x1x64xf32> + %717 = stablehlo.concatenate %716, %714, dim = 3 : (tensor<3x8x1x64xf32>, tensor<3x8x1x64xf32>) -> tensor<3x8x1x128xf32> + %718 = stablehlo.broadcast_in_dim %717, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %719 = stablehlo.broadcast_in_dim %699, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %720 = stablehlo.multiply %718, %719 : tensor<3x8x1x128xf32> + %721 = stablehlo.add %713, %720 : tensor<3x8x1x128xf32> + %722 = stablehlo.concatenate %arg67, %721, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %723 = stablehlo.concatenate %arg68, %693, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %724 = stablehlo.reshape %722 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %725 = stablehlo.broadcast_in_dim %724, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %726 = stablehlo.reshape %725 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %727 = stablehlo.reshape %723 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %728 = stablehlo.broadcast_in_dim %727, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %729 = stablehlo.reshape %728 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %730 = 
stablehlo.transpose %726, dims = [0, 1, 3, 2] : (tensor<3x32x8x128xf32>) -> tensor<3x32x128x8xf32> + %731 = stablehlo.reshape %710 : (tensor<3x32x1x128xf32>) -> tensor<96x1x128xf32> + %732 = stablehlo.reshape %730 : (tensor<3x32x128x8xf32>) -> tensor<96x128x8xf32> + %733 = stablehlo.broadcast_in_dim %732, dims = [0, 1, 2] : (tensor<96x128x8xf32>) -> tensor<96x128x8xf32> + %734 = stablehlo.dot_general %731, %733, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x128xf32>, tensor<96x128x8xf32>) -> tensor<96x1x8xf32> + %735 = stablehlo.reshape %734 : (tensor<96x1x8xf32>) -> tensor<3x32x1x8xf32> + %736 = stablehlo.broadcast_in_dim %735, dims = [0, 1, 2, 3] : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %737 = stablehlo.divide %736, %89 : tensor<3x32x1x8xf32> + %738 = stablehlo.custom_call @byteir.softmax(%737) {byteir_attrs = {axis = 3 : i64}} : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %739 = stablehlo.reshape %738 : (tensor<3x32x1x8xf32>) -> tensor<96x1x8xf32> + %740 = stablehlo.reshape %729 : (tensor<3x32x8x128xf32>) -> tensor<96x8x128xf32> + %741 = stablehlo.broadcast_in_dim %740, dims = [0, 1, 2] : (tensor<96x8x128xf32>) -> tensor<96x8x128xf32> + %742 = stablehlo.dot_general %739, %741, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x8xf32>, tensor<96x8x128xf32>) -> tensor<96x1x128xf32> + %743 = stablehlo.reshape %742 : (tensor<96x1x128xf32>) -> tensor<3x32x1x128xf32> + %744 = stablehlo.transpose %743, dims = [0, 2, 1, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x1x32x128xf32> + %745 = stablehlo.reshape %744 : (tensor<3x1x32x128xf32>) -> tensor<3x1x4096xf32> + %746 = stablehlo.reshape %745 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %747 = stablehlo.dot %746, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %748 = stablehlo.reshape %747 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %749 = stablehlo.add %671, %748 : tensor<3x1x4096xf32> + %750 = stablehlo.broadcast_in_dim %749, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %751 = stablehlo.power %750, %15 : tensor<3x1x4096xf32> + %752 = stablehlo.reduce(%751 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %753 = stablehlo.reshape %752 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %754 = stablehlo.broadcast_in_dim %753, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %755 = stablehlo.divide %754, %21 : tensor<3x1x1xf32> + %756 = stablehlo.broadcast_in_dim %755, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %757 = stablehlo.add %756, %25 : tensor<3x1x1xf32> + %758 = stablehlo.rsqrt %757 : tensor<3x1x1xf32> + %759 = stablehlo.broadcast_in_dim %758, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %760 = stablehlo.multiply %750, %759 : tensor<3x1x4096xf32> + %761 = stablehlo.broadcast_in_dim %760, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %762 = stablehlo.multiply %761, %31 : tensor<3x1x4096xf32> + %763 = stablehlo.reshape %762 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %764 = stablehlo.dot %763, %117 : (tensor<3x4096xf32>, tensor<4096x8xf32>) -> tensor<3x8xf32> + %765 = stablehlo.custom_call @byteir.softmax(%764) {byteir_attrs = {axis = 1 : i64}} : (tensor<3x8xf32>) -> tensor<3x8xf32> + %766:2 = stablehlo.custom_call @byteir.top_k(%765) {byteir_attrs = {axis = [1], k = 2 : i64, sorted = true}} : (tensor<3x8xf32>) -> (tensor<3x2xf32>, tensor<3x2xi64>) + %767 = stablehlo.reduce(%766#0 init: %cst_20) 
applies stablehlo.add across dimensions = [1] : (tensor<3x2xf32>, tensor) -> tensor<3xf32> + %768 = stablehlo.reshape %767 : (tensor<3xf32>) -> tensor<3x1xf32> + %769 = stablehlo.broadcast_in_dim %766#0, dims = [0, 1] : (tensor<3x2xf32>) -> tensor<3x2xf32> + %770 = stablehlo.broadcast_in_dim %768, dims = [0, 1] : (tensor<3x1xf32>) -> tensor<3x2xf32> + %771 = stablehlo.divide %769, %770 : tensor<3x2xf32> + %772 = stablehlo.reshape %766#1 : (tensor<3x2xi64>) -> tensor<3x2x1xi64> + %773 = stablehlo.broadcast_in_dim %772, dims = [0, 1, 2] : (tensor<3x2x1xi64>) -> tensor<3x2x8xi64> + %774 = stablehlo.compare EQ, %773, %137, SIGNED : (tensor<3x2x8xi64>, tensor<3x2x8xi64>) -> tensor<3x2x8xi1> + %775 = stablehlo.convert %774 : (tensor<3x2x8xi1>) -> tensor<3x2x8xi64> + %776 = stablehlo.transpose %775, dims = [2, 1, 0] : (tensor<3x2x8xi64>) -> tensor<8x2x3xi64> + %777 = stablehlo.slice %776 [0:1, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %778 = stablehlo.reshape %777 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %779 = stablehlo.custom_call @byteir.non_zero(%778) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_240 = tensor.dim %779, %c0 : tensor + %780 = arith.index_cast %dim_240 : index to i64 + %from_elements_241 = tensor.from_elements %780, %c1_i64 : tensor<2xi64> + %781 = stablehlo.real_dynamic_slice %779, %c_22, %from_elements_241, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_242 = tensor.dim %781, %c0 : tensor + %782 = arith.index_cast %dim_242 : index to i64 + %from_elements_243 = tensor.from_elements %782 : tensor<1xi64> + %783 = stablehlo.dynamic_reshape %781, %from_elements_243 : (tensor, tensor<1xi64>) -> tensor + %from_elements_244 = tensor.from_elements %780, %c2_i64 : tensor<2xi64> + %784 = stablehlo.real_dynamic_slice %779, %c_24, %from_elements_244, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_245 = tensor.dim %784, %c0 : tensor + %785 = arith.index_cast %dim_245 : index to i64 + %from_elements_246 = tensor.from_elements %785 : tensor<1xi64> + %786 = stablehlo.dynamic_reshape %784, %from_elements_246 : (tensor, tensor<1xi64>) -> tensor + %787 = stablehlo.reshape %763 : (tensor<3x4096xf32>) -> tensor<1x3x4096xf32> + %dim_247 = tensor.dim %786, %c0 : tensor + %788 = arith.index_cast %dim_247 : index to i64 + %from_elements_248 = tensor.from_elements %788, %c1_i64 : tensor<2xi64> + %789 = stablehlo.dynamic_reshape %786, %from_elements_248 : (tensor, tensor<2xi64>) -> tensor + %dim_249 = tensor.dim %789, %c0 : tensor + %790 = arith.index_cast %dim_249 : index to i64 + %from_elements_250 = tensor.from_elements %c1_i64, %790, %c4096_i64 : tensor<3xi64> + %791 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_250, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_251 = tensor.dim %791, %c1 : tensor<1x?x4096xi64> + %792 = arith.index_cast %dim_251 : index to i64 + %from_elements_252 = tensor.from_elements %c1_i64, %792, %c4096_i64, %c1_i64 : tensor<4xi64> + %793 = stablehlo.dynamic_reshape %791, %from_elements_252 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %794 = stablehlo.dynamic_broadcast_in_dim %789, %from_elements_250, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_253 = tensor.dim %794, %c1 : tensor<1x?x4096xi64> + %795 = arith.index_cast %dim_253 : index to i64 + %from_elements_254 = tensor.from_elements %c1_i64, %795, %c4096_i64, %c1_i64 : tensor<4xi64> + %796 = stablehlo.dynamic_reshape %794, 
%from_elements_254 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %797 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_250, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_255 = tensor.dim %797, %c1 : tensor<1x?x4096xi64> + %798 = arith.index_cast %dim_255 : index to i64 + %from_elements_256 = tensor.from_elements %c1_i64, %798, %c4096_i64, %c1_i64 : tensor<4xi64> + %799 = stablehlo.dynamic_reshape %797, %from_elements_256 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %800 = stablehlo.concatenate %793, %796, %799, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %801 = "stablehlo.gather"(%787, %800) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %802 = shape.shape_of %801 : tensor<1x?x4096xf32> -> tensor<3xindex> + %803 = shape.num_elements %802 : tensor<3xindex> -> index + %804 = stablehlo.compute_reshape_shape %803, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %805 = stablehlo.dynamic_reshape %801, %804 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %806 = stablehlo.dot %805, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %807 = stablehlo.logistic %806 : tensor + %808 = shape.shape_of %807 : tensor -> tensor<2xindex> + %809 = shape.shape_of %806 : tensor -> tensor<2xindex> + %810 = shape.cstr_broadcastable %808, %809 : tensor<2xindex>, tensor<2xindex> + %811 = shape.assuming %810 -> (tensor) { + %19688 = shape.broadcast %808, %809 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %807, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %806, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %812 = shape.shape_of %811 : tensor -> tensor<2xindex> + %813 = shape.cstr_broadcastable %812, %809 : tensor<2xindex>, tensor<2xindex> + %814 = shape.assuming %813 -> (tensor) { + %19688 = shape.broadcast %812, %809 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %811, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %806, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %815 = stablehlo.dot %814, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %816 = stablehlo.reshape %771 : (tensor<3x2xf32>) -> tensor<3x2x1xf32> + %dim_257 = tensor.dim %786, %c0 : tensor + %817 = arith.index_cast %dim_257 : index to i64 + %from_elements_258 = tensor.from_elements %817, %c1_i64 : tensor<2xi64> + %818 = stablehlo.dynamic_reshape %786, %from_elements_258 : (tensor, tensor<2xi64>) -> tensor + %dim_259 = tensor.dim %783, %c0 : tensor + %819 = arith.index_cast %dim_259 : index to i64 + %from_elements_260 = tensor.from_elements %819, %c1_i64 : tensor<2xi64> + %820 = stablehlo.dynamic_reshape %783, %from_elements_260 : (tensor, tensor<2xi64>) -> tensor + %821 = stablehlo.concatenate %818, %820, dim = 1 : (tensor, tensor) -> tensor + %822 = "stablehlo.gather"(%816, %821) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %823 = shape.shape_of 
%815 : tensor -> tensor<2xindex> + %824 = shape.shape_of %822 : tensor -> tensor<2xindex> + %825 = shape.cstr_broadcastable %823, %824 : tensor<2xindex>, tensor<2xindex> + %826 = shape.assuming %825 -> (tensor) { + %19688 = shape.broadcast %823, %824 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %815, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %822, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %827 = shape.shape_of %826 : tensor -> tensor<2xindex> + %828 = stablehlo.dynamic_broadcast_in_dim %826, %827, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %829 = stablehlo.dynamic_broadcast_in_dim %213, %827, dims = [] : (tensor, tensor<2xindex>) -> tensor + %830 = stablehlo.multiply %828, %829 : tensor + %dim_261 = tensor.dim %789, %c0 : tensor + %831 = arith.index_cast %dim_261 : index to i64 + %dim_262 = tensor.dim %826, %c0 : tensor + %832 = arith.index_cast %dim_262 : index to i64 + %833 = arith.maxsi %831, %832 : i64 + %834 = arith.index_cast %833 : i64 to index + %from_elements_263 = tensor.from_elements %834, %c4096 : tensor<2xindex> + %835 = stablehlo.dynamic_broadcast_in_dim %789, %from_elements_263, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_264 = tensor.dim %835, %c0 : tensor + %836 = arith.index_cast %dim_264 : index to i64 + %from_elements_265 = tensor.from_elements %836, %c4096_i64 : tensor<2xi64> + %837 = stablehlo.real_dynamic_slice %830, %c_22, %from_elements_265, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_266 = tensor.from_elements %836, %c4096_i64, %c1_i64 : tensor<3xi64> + %838 = stablehlo.dynamic_reshape %835, %from_elements_266 : (tensor, tensor<3xi64>) -> tensor + %839 = stablehlo.dynamic_iota %from_elements_266, dim = 1 : (tensor<3xi64>) -> tensor + %840 = stablehlo.concatenate %838, %839, dim = 2 : (tensor, tensor) -> tensor + %841 = "stablehlo.scatter"(%cst_2, %840, %837) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %842 = stablehlo.slice %776 [1:2, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %843 = stablehlo.reshape %842 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %844 = stablehlo.custom_call @byteir.non_zero(%843) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_267 = tensor.dim %844, %c0 : tensor + %845 = arith.index_cast %dim_267 : index to i64 + %from_elements_268 = tensor.from_elements %845, %c1_i64 : tensor<2xi64> + %846 = stablehlo.real_dynamic_slice %844, %c_22, %from_elements_268, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_269 = tensor.dim %846, %c0 : tensor + %847 = arith.index_cast %dim_269 : index to i64 + %from_elements_270 = tensor.from_elements %847 : tensor<1xi64> + %848 = stablehlo.dynamic_reshape %846, %from_elements_270 : (tensor, tensor<1xi64>) -> tensor + %from_elements_271 = tensor.from_elements %845, %c2_i64 : tensor<2xi64> + %849 = stablehlo.real_dynamic_slice %844, %c_24, %from_elements_271, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_272 = tensor.dim %849, %c0 : tensor + %850 = arith.index_cast %dim_272 : index 
to i64 + %from_elements_273 = tensor.from_elements %850 : tensor<1xi64> + %851 = stablehlo.dynamic_reshape %849, %from_elements_273 : (tensor, tensor<1xi64>) -> tensor + %dim_274 = tensor.dim %851, %c0 : tensor + %852 = arith.index_cast %dim_274 : index to i64 + %from_elements_275 = tensor.from_elements %852, %c1_i64 : tensor<2xi64> + %853 = stablehlo.dynamic_reshape %851, %from_elements_275 : (tensor, tensor<2xi64>) -> tensor + %dim_276 = tensor.dim %853, %c0 : tensor + %854 = arith.index_cast %dim_276 : index to i64 + %from_elements_277 = tensor.from_elements %c1_i64, %854, %c4096_i64 : tensor<3xi64> + %855 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_277, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_278 = tensor.dim %855, %c1 : tensor<1x?x4096xi64> + %856 = arith.index_cast %dim_278 : index to i64 + %from_elements_279 = tensor.from_elements %c1_i64, %856, %c4096_i64, %c1_i64 : tensor<4xi64> + %857 = stablehlo.dynamic_reshape %855, %from_elements_279 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %858 = stablehlo.dynamic_broadcast_in_dim %853, %from_elements_277, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_280 = tensor.dim %858, %c1 : tensor<1x?x4096xi64> + %859 = arith.index_cast %dim_280 : index to i64 + %from_elements_281 = tensor.from_elements %c1_i64, %859, %c4096_i64, %c1_i64 : tensor<4xi64> + %860 = stablehlo.dynamic_reshape %858, %from_elements_281 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %861 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_277, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_282 = tensor.dim %861, %c1 : tensor<1x?x4096xi64> + %862 = arith.index_cast %dim_282 : index to i64 + %from_elements_283 = tensor.from_elements %c1_i64, %862, %c4096_i64, %c1_i64 : tensor<4xi64> + %863 = stablehlo.dynamic_reshape %861, %from_elements_283 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %864 = stablehlo.concatenate %857, %860, %863, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %865 = "stablehlo.gather"(%787, %864) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %866 = shape.shape_of %865 : tensor<1x?x4096xf32> -> tensor<3xindex> + %867 = shape.num_elements %866 : tensor<3xindex> -> index + %868 = stablehlo.compute_reshape_shape %867, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %869 = stablehlo.dynamic_reshape %865, %868 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %870 = stablehlo.dot %869, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %871 = stablehlo.logistic %870 : tensor + %872 = shape.shape_of %871 : tensor -> tensor<2xindex> + %873 = shape.shape_of %870 : tensor -> tensor<2xindex> + %874 = shape.cstr_broadcastable %872, %873 : tensor<2xindex>, tensor<2xindex> + %875 = shape.assuming %874 -> (tensor) { + %19688 = shape.broadcast %872, %873 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %871, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %870, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %876 = shape.shape_of %875 : tensor -> tensor<2xindex> + %877 = 
shape.cstr_broadcastable %876, %873 : tensor<2xindex>, tensor<2xindex> + %878 = shape.assuming %877 -> (tensor) { + %19688 = shape.broadcast %876, %873 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %875, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %870, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %879 = stablehlo.dot %878, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_284 = tensor.dim %851, %c0 : tensor + %880 = arith.index_cast %dim_284 : index to i64 + %from_elements_285 = tensor.from_elements %880, %c1_i64 : tensor<2xi64> + %881 = stablehlo.dynamic_reshape %851, %from_elements_285 : (tensor, tensor<2xi64>) -> tensor + %dim_286 = tensor.dim %848, %c0 : tensor + %882 = arith.index_cast %dim_286 : index to i64 + %from_elements_287 = tensor.from_elements %882, %c1_i64 : tensor<2xi64> + %883 = stablehlo.dynamic_reshape %848, %from_elements_287 : (tensor, tensor<2xi64>) -> tensor + %884 = stablehlo.concatenate %881, %883, dim = 1 : (tensor, tensor) -> tensor + %885 = "stablehlo.gather"(%816, %884) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %886 = shape.shape_of %879 : tensor -> tensor<2xindex> + %887 = shape.shape_of %885 : tensor -> tensor<2xindex> + %888 = shape.cstr_broadcastable %886, %887 : tensor<2xindex>, tensor<2xindex> + %889 = shape.assuming %888 -> (tensor) { + %19688 = shape.broadcast %886, %887 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %879, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %885, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %890 = shape.shape_of %889 : tensor -> tensor<2xindex> + %891 = stablehlo.dynamic_broadcast_in_dim %889, %890, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %892 = stablehlo.dynamic_broadcast_in_dim %213, %890, dims = [] : (tensor, tensor<2xindex>) -> tensor + %893 = stablehlo.multiply %891, %892 : tensor + %dim_288 = tensor.dim %853, %c0 : tensor + %894 = arith.index_cast %dim_288 : index to i64 + %dim_289 = tensor.dim %889, %c0 : tensor + %895 = arith.index_cast %dim_289 : index to i64 + %896 = arith.maxsi %894, %895 : i64 + %897 = arith.index_cast %896 : i64 to index + %from_elements_290 = tensor.from_elements %897, %c4096 : tensor<2xindex> + %898 = stablehlo.dynamic_broadcast_in_dim %853, %from_elements_290, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_291 = tensor.dim %898, %c0 : tensor + %899 = arith.index_cast %dim_291 : index to i64 + %from_elements_292 = tensor.from_elements %899, %c4096_i64 : tensor<2xi64> + %900 = stablehlo.real_dynamic_slice %893, %c_22, %from_elements_292, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_293 = tensor.from_elements %899, %c4096_i64, %c1_i64 : tensor<3xi64> + %901 = stablehlo.dynamic_reshape %898, %from_elements_293 : (tensor, tensor<3xi64>) -> tensor + %902 = stablehlo.dynamic_iota %from_elements_293, dim = 1 : (tensor<3xi64>) -> tensor + %903 = stablehlo.concatenate %901, %902, dim = 2 : (tensor, tensor) -> tensor + %904 = "stablehlo.scatter"(%841, %903, %900) <{indices_are_sorted = 
false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %905 = stablehlo.slice %776 [2:3, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %906 = stablehlo.reshape %905 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %907 = stablehlo.custom_call @byteir.non_zero(%906) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_294 = tensor.dim %907, %c0 : tensor + %908 = arith.index_cast %dim_294 : index to i64 + %from_elements_295 = tensor.from_elements %908, %c1_i64 : tensor<2xi64> + %909 = stablehlo.real_dynamic_slice %907, %c_22, %from_elements_295, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_296 = tensor.dim %909, %c0 : tensor + %910 = arith.index_cast %dim_296 : index to i64 + %from_elements_297 = tensor.from_elements %910 : tensor<1xi64> + %911 = stablehlo.dynamic_reshape %909, %from_elements_297 : (tensor, tensor<1xi64>) -> tensor + %from_elements_298 = tensor.from_elements %908, %c2_i64 : tensor<2xi64> + %912 = stablehlo.real_dynamic_slice %907, %c_24, %from_elements_298, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_299 = tensor.dim %912, %c0 : tensor + %913 = arith.index_cast %dim_299 : index to i64 + %from_elements_300 = tensor.from_elements %913 : tensor<1xi64> + %914 = stablehlo.dynamic_reshape %912, %from_elements_300 : (tensor, tensor<1xi64>) -> tensor + %dim_301 = tensor.dim %914, %c0 : tensor + %915 = arith.index_cast %dim_301 : index to i64 + %from_elements_302 = tensor.from_elements %915, %c1_i64 : tensor<2xi64> + %916 = stablehlo.dynamic_reshape %914, %from_elements_302 : (tensor, tensor<2xi64>) -> tensor + %dim_303 = tensor.dim %916, %c0 : tensor + %917 = arith.index_cast %dim_303 : index to i64 + %from_elements_304 = tensor.from_elements %c1_i64, %917, %c4096_i64 : tensor<3xi64> + %918 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_304, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_305 = tensor.dim %918, %c1 : tensor<1x?x4096xi64> + %919 = arith.index_cast %dim_305 : index to i64 + %from_elements_306 = tensor.from_elements %c1_i64, %919, %c4096_i64, %c1_i64 : tensor<4xi64> + %920 = stablehlo.dynamic_reshape %918, %from_elements_306 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %921 = stablehlo.dynamic_broadcast_in_dim %916, %from_elements_304, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_307 = tensor.dim %921, %c1 : tensor<1x?x4096xi64> + %922 = arith.index_cast %dim_307 : index to i64 + %from_elements_308 = tensor.from_elements %c1_i64, %922, %c4096_i64, %c1_i64 : tensor<4xi64> + %923 = stablehlo.dynamic_reshape %921, %from_elements_308 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %924 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_304, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_309 = tensor.dim %924, %c1 : tensor<1x?x4096xi64> + %925 = arith.index_cast %dim_309 : index to i64 + %from_elements_310 = tensor.from_elements %c1_i64, %925, %c4096_i64, %c1_i64 : tensor<4xi64> + %926 = stablehlo.dynamic_reshape %924, %from_elements_310 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %927 = stablehlo.concatenate %920, %923, %926, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, 
tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %928 = "stablehlo.gather"(%787, %927) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %929 = shape.shape_of %928 : tensor<1x?x4096xf32> -> tensor<3xindex> + %930 = shape.num_elements %929 : tensor<3xindex> -> index + %931 = stablehlo.compute_reshape_shape %930, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %932 = stablehlo.dynamic_reshape %928, %931 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %933 = stablehlo.dot %932, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %934 = stablehlo.logistic %933 : tensor + %935 = shape.shape_of %934 : tensor -> tensor<2xindex> + %936 = shape.shape_of %933 : tensor -> tensor<2xindex> + %937 = shape.cstr_broadcastable %935, %936 : tensor<2xindex>, tensor<2xindex> + %938 = shape.assuming %937 -> (tensor) { + %19688 = shape.broadcast %935, %936 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %934, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %933, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %939 = shape.shape_of %938 : tensor -> tensor<2xindex> + %940 = shape.cstr_broadcastable %939, %936 : tensor<2xindex>, tensor<2xindex> + %941 = shape.assuming %940 -> (tensor) { + %19688 = shape.broadcast %939, %936 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %938, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %933, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %942 = stablehlo.dot %941, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_311 = tensor.dim %914, %c0 : tensor + %943 = arith.index_cast %dim_311 : index to i64 + %from_elements_312 = tensor.from_elements %943, %c1_i64 : tensor<2xi64> + %944 = stablehlo.dynamic_reshape %914, %from_elements_312 : (tensor, tensor<2xi64>) -> tensor + %dim_313 = tensor.dim %911, %c0 : tensor + %945 = arith.index_cast %dim_313 : index to i64 + %from_elements_314 = tensor.from_elements %945, %c1_i64 : tensor<2xi64> + %946 = stablehlo.dynamic_reshape %911, %from_elements_314 : (tensor, tensor<2xi64>) -> tensor + %947 = stablehlo.concatenate %944, %946, dim = 1 : (tensor, tensor) -> tensor + %948 = "stablehlo.gather"(%816, %947) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %949 = shape.shape_of %942 : tensor -> tensor<2xindex> + %950 = shape.shape_of %948 : tensor -> tensor<2xindex> + %951 = shape.cstr_broadcastable %949, %950 : tensor<2xindex>, tensor<2xindex> + %952 = shape.assuming %951 -> (tensor) { + %19688 = shape.broadcast %949, %950 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %942, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %948, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %953 = shape.shape_of %952 : tensor -> tensor<2xindex> + %954 = stablehlo.dynamic_broadcast_in_dim %952, 
%953, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %955 = stablehlo.dynamic_broadcast_in_dim %213, %953, dims = [] : (tensor, tensor<2xindex>) -> tensor + %956 = stablehlo.multiply %954, %955 : tensor + %dim_315 = tensor.dim %916, %c0 : tensor + %957 = arith.index_cast %dim_315 : index to i64 + %dim_316 = tensor.dim %952, %c0 : tensor + %958 = arith.index_cast %dim_316 : index to i64 + %959 = arith.maxsi %957, %958 : i64 + %960 = arith.index_cast %959 : i64 to index + %from_elements_317 = tensor.from_elements %960, %c4096 : tensor<2xindex> + %961 = stablehlo.dynamic_broadcast_in_dim %916, %from_elements_317, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_318 = tensor.dim %961, %c0 : tensor + %962 = arith.index_cast %dim_318 : index to i64 + %from_elements_319 = tensor.from_elements %962, %c4096_i64 : tensor<2xi64> + %963 = stablehlo.real_dynamic_slice %956, %c_22, %from_elements_319, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_320 = tensor.from_elements %962, %c4096_i64, %c1_i64 : tensor<3xi64> + %964 = stablehlo.dynamic_reshape %961, %from_elements_320 : (tensor, tensor<3xi64>) -> tensor + %965 = stablehlo.dynamic_iota %from_elements_320, dim = 1 : (tensor<3xi64>) -> tensor + %966 = stablehlo.concatenate %964, %965, dim = 2 : (tensor, tensor) -> tensor + %967 = "stablehlo.scatter"(%904, %966, %963) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %968 = stablehlo.slice %776 [3:4, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %969 = stablehlo.reshape %968 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %970 = stablehlo.custom_call @byteir.non_zero(%969) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_321 = tensor.dim %970, %c0 : tensor + %971 = arith.index_cast %dim_321 : index to i64 + %from_elements_322 = tensor.from_elements %971, %c1_i64 : tensor<2xi64> + %972 = stablehlo.real_dynamic_slice %970, %c_22, %from_elements_322, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_323 = tensor.dim %972, %c0 : tensor + %973 = arith.index_cast %dim_323 : index to i64 + %from_elements_324 = tensor.from_elements %973 : tensor<1xi64> + %974 = stablehlo.dynamic_reshape %972, %from_elements_324 : (tensor, tensor<1xi64>) -> tensor + %from_elements_325 = tensor.from_elements %971, %c2_i64 : tensor<2xi64> + %975 = stablehlo.real_dynamic_slice %970, %c_24, %from_elements_325, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_326 = tensor.dim %975, %c0 : tensor + %976 = arith.index_cast %dim_326 : index to i64 + %from_elements_327 = tensor.from_elements %976 : tensor<1xi64> + %977 = stablehlo.dynamic_reshape %975, %from_elements_327 : (tensor, tensor<1xi64>) -> tensor + %dim_328 = tensor.dim %977, %c0 : tensor + %978 = arith.index_cast %dim_328 : index to i64 + %from_elements_329 = tensor.from_elements %978, %c1_i64 : tensor<2xi64> + %979 = stablehlo.dynamic_reshape %977, %from_elements_329 : (tensor, tensor<2xi64>) -> tensor + %dim_330 = tensor.dim %979, %c0 : tensor + %980 = arith.index_cast %dim_330 : index to i64 + %from_elements_331 = tensor.from_elements %c1_i64, %980, %c4096_i64 : tensor<3xi64> + %981 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_331, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) 
-> tensor<1x?x4096xi64> + %dim_332 = tensor.dim %981, %c1 : tensor<1x?x4096xi64> + %982 = arith.index_cast %dim_332 : index to i64 + %from_elements_333 = tensor.from_elements %c1_i64, %982, %c4096_i64, %c1_i64 : tensor<4xi64> + %983 = stablehlo.dynamic_reshape %981, %from_elements_333 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %984 = stablehlo.dynamic_broadcast_in_dim %979, %from_elements_331, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_334 = tensor.dim %984, %c1 : tensor<1x?x4096xi64> + %985 = arith.index_cast %dim_334 : index to i64 + %from_elements_335 = tensor.from_elements %c1_i64, %985, %c4096_i64, %c1_i64 : tensor<4xi64> + %986 = stablehlo.dynamic_reshape %984, %from_elements_335 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %987 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_331, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_336 = tensor.dim %987, %c1 : tensor<1x?x4096xi64> + %988 = arith.index_cast %dim_336 : index to i64 + %from_elements_337 = tensor.from_elements %c1_i64, %988, %c4096_i64, %c1_i64 : tensor<4xi64> + %989 = stablehlo.dynamic_reshape %987, %from_elements_337 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %990 = stablehlo.concatenate %983, %986, %989, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %991 = "stablehlo.gather"(%787, %990) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %992 = shape.shape_of %991 : tensor<1x?x4096xf32> -> tensor<3xindex> + %993 = shape.num_elements %992 : tensor<3xindex> -> index + %994 = stablehlo.compute_reshape_shape %993, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %995 = stablehlo.dynamic_reshape %991, %994 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %996 = stablehlo.dot %995, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %997 = stablehlo.logistic %996 : tensor + %998 = shape.shape_of %997 : tensor -> tensor<2xindex> + %999 = shape.shape_of %996 : tensor -> tensor<2xindex> + %1000 = shape.cstr_broadcastable %998, %999 : tensor<2xindex>, tensor<2xindex> + %1001 = shape.assuming %1000 -> (tensor) { + %19688 = shape.broadcast %998, %999 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %997, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %996, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %1002 = shape.shape_of %1001 : tensor -> tensor<2xindex> + %1003 = shape.cstr_broadcastable %1002, %999 : tensor<2xindex>, tensor<2xindex> + %1004 = shape.assuming %1003 -> (tensor) { + %19688 = shape.broadcast %1002, %999 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %1001, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %996, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %1005 = stablehlo.dot %1004, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_338 = tensor.dim %977, %c0 : tensor + %1006 = arith.index_cast %dim_338 : index to i64 + %from_elements_339 = 
tensor.from_elements %1006, %c1_i64 : tensor<2xi64> + %1007 = stablehlo.dynamic_reshape %977, %from_elements_339 : (tensor, tensor<2xi64>) -> tensor + %dim_340 = tensor.dim %974, %c0 : tensor + %1008 = arith.index_cast %dim_340 : index to i64 + %from_elements_341 = tensor.from_elements %1008, %c1_i64 : tensor<2xi64> + %1009 = stablehlo.dynamic_reshape %974, %from_elements_341 : (tensor, tensor<2xi64>) -> tensor + %1010 = stablehlo.concatenate %1007, %1009, dim = 1 : (tensor, tensor) -> tensor + %1011 = "stablehlo.gather"(%816, %1010) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %1012 = shape.shape_of %1005 : tensor -> tensor<2xindex> + %1013 = shape.shape_of %1011 : tensor -> tensor<2xindex> + %1014 = shape.cstr_broadcastable %1012, %1013 : tensor<2xindex>, tensor<2xindex> + %1015 = shape.assuming %1014 -> (tensor) { + %19688 = shape.broadcast %1012, %1013 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %1005, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %1011, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %1016 = shape.shape_of %1015 : tensor -> tensor<2xindex> + %1017 = stablehlo.dynamic_broadcast_in_dim %1015, %1016, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %1018 = stablehlo.dynamic_broadcast_in_dim %213, %1016, dims = [] : (tensor, tensor<2xindex>) -> tensor + %1019 = stablehlo.multiply %1017, %1018 : tensor + %dim_342 = tensor.dim %979, %c0 : tensor + %1020 = arith.index_cast %dim_342 : index to i64 + %dim_343 = tensor.dim %1015, %c0 : tensor + %1021 = arith.index_cast %dim_343 : index to i64 + %1022 = arith.maxsi %1020, %1021 : i64 + %1023 = arith.index_cast %1022 : i64 to index + %from_elements_344 = tensor.from_elements %1023, %c4096 : tensor<2xindex> + %1024 = stablehlo.dynamic_broadcast_in_dim %979, %from_elements_344, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_345 = tensor.dim %1024, %c0 : tensor + %1025 = arith.index_cast %dim_345 : index to i64 + %from_elements_346 = tensor.from_elements %1025, %c4096_i64 : tensor<2xi64> + %1026 = stablehlo.real_dynamic_slice %1019, %c_22, %from_elements_346, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_347 = tensor.from_elements %1025, %c4096_i64, %c1_i64 : tensor<3xi64> + %1027 = stablehlo.dynamic_reshape %1024, %from_elements_347 : (tensor, tensor<3xi64>) -> tensor + %1028 = stablehlo.dynamic_iota %from_elements_347, dim = 1 : (tensor<3xi64>) -> tensor + %1029 = stablehlo.concatenate %1027, %1028, dim = 2 : (tensor, tensor) -> tensor + %1030 = "stablehlo.scatter"(%967, %1029, %1026) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %1031 = stablehlo.slice %776 [4:5, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %1032 = stablehlo.reshape %1031 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %1033 = stablehlo.custom_call @byteir.non_zero(%1032) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_348 = tensor.dim %1033, %c0 : tensor + %1034 = arith.index_cast %dim_348 : index to i64 + %from_elements_349 = 
tensor.from_elements %1034, %c1_i64 : tensor<2xi64> + %1035 = stablehlo.real_dynamic_slice %1033, %c_22, %from_elements_349, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_350 = tensor.dim %1035, %c0 : tensor + %1036 = arith.index_cast %dim_350 : index to i64 + %from_elements_351 = tensor.from_elements %1036 : tensor<1xi64> + %1037 = stablehlo.dynamic_reshape %1035, %from_elements_351 : (tensor, tensor<1xi64>) -> tensor + %from_elements_352 = tensor.from_elements %1034, %c2_i64 : tensor<2xi64> + %1038 = stablehlo.real_dynamic_slice %1033, %c_24, %from_elements_352, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_353 = tensor.dim %1038, %c0 : tensor + %1039 = arith.index_cast %dim_353 : index to i64 + %from_elements_354 = tensor.from_elements %1039 : tensor<1xi64> + %1040 = stablehlo.dynamic_reshape %1038, %from_elements_354 : (tensor, tensor<1xi64>) -> tensor + %dim_355 = tensor.dim %1040, %c0 : tensor + %1041 = arith.index_cast %dim_355 : index to i64 + %from_elements_356 = tensor.from_elements %1041, %c1_i64 : tensor<2xi64> + %1042 = stablehlo.dynamic_reshape %1040, %from_elements_356 : (tensor, tensor<2xi64>) -> tensor + %dim_357 = tensor.dim %1042, %c0 : tensor + %1043 = arith.index_cast %dim_357 : index to i64 + %from_elements_358 = tensor.from_elements %c1_i64, %1043, %c4096_i64 : tensor<3xi64> + %1044 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_358, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_359 = tensor.dim %1044, %c1 : tensor<1x?x4096xi64> + %1045 = arith.index_cast %dim_359 : index to i64 + %from_elements_360 = tensor.from_elements %c1_i64, %1045, %c4096_i64, %c1_i64 : tensor<4xi64> + %1046 = stablehlo.dynamic_reshape %1044, %from_elements_360 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %1047 = stablehlo.dynamic_broadcast_in_dim %1042, %from_elements_358, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_361 = tensor.dim %1047, %c1 : tensor<1x?x4096xi64> + %1048 = arith.index_cast %dim_361 : index to i64 + %from_elements_362 = tensor.from_elements %c1_i64, %1048, %c4096_i64, %c1_i64 : tensor<4xi64> + %1049 = stablehlo.dynamic_reshape %1047, %from_elements_362 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %1050 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_358, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_363 = tensor.dim %1050, %c1 : tensor<1x?x4096xi64> + %1051 = arith.index_cast %dim_363 : index to i64 + %from_elements_364 = tensor.from_elements %c1_i64, %1051, %c4096_i64, %c1_i64 : tensor<4xi64> + %1052 = stablehlo.dynamic_reshape %1050, %from_elements_364 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %1053 = stablehlo.concatenate %1046, %1049, %1052, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %1054 = "stablehlo.gather"(%787, %1053) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %1055 = shape.shape_of %1054 : tensor<1x?x4096xf32> -> tensor<3xindex> + %1056 = shape.num_elements %1055 : tensor<3xindex> -> index + %1057 = stablehlo.compute_reshape_shape %1056, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %1058 = stablehlo.dynamic_reshape %1054, %1057 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %1059 = stablehlo.dot %1058, %190 : 
(tensor, tensor<4096x14336xf32>) -> tensor + %1060 = stablehlo.logistic %1059 : tensor + %1061 = shape.shape_of %1060 : tensor -> tensor<2xindex> + %1062 = shape.shape_of %1059 : tensor -> tensor<2xindex> + %1063 = shape.cstr_broadcastable %1061, %1062 : tensor<2xindex>, tensor<2xindex> + %1064 = shape.assuming %1063 -> (tensor) { + %19688 = shape.broadcast %1061, %1062 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %1060, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %1059, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %1065 = shape.shape_of %1064 : tensor -> tensor<2xindex> + %1066 = shape.cstr_broadcastable %1065, %1062 : tensor<2xindex>, tensor<2xindex> + %1067 = shape.assuming %1066 -> (tensor) { + %19688 = shape.broadcast %1065, %1062 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %1064, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %1059, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %1068 = stablehlo.dot %1067, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_365 = tensor.dim %1040, %c0 : tensor + %1069 = arith.index_cast %dim_365 : index to i64 + %from_elements_366 = tensor.from_elements %1069, %c1_i64 : tensor<2xi64> + %1070 = stablehlo.dynamic_reshape %1040, %from_elements_366 : (tensor, tensor<2xi64>) -> tensor + %dim_367 = tensor.dim %1037, %c0 : tensor + %1071 = arith.index_cast %dim_367 : index to i64 + %from_elements_368 = tensor.from_elements %1071, %c1_i64 : tensor<2xi64> + %1072 = stablehlo.dynamic_reshape %1037, %from_elements_368 : (tensor, tensor<2xi64>) -> tensor + %1073 = stablehlo.concatenate %1070, %1072, dim = 1 : (tensor, tensor) -> tensor + %1074 = "stablehlo.gather"(%816, %1073) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %1075 = shape.shape_of %1068 : tensor -> tensor<2xindex> + %1076 = shape.shape_of %1074 : tensor -> tensor<2xindex> + %1077 = shape.cstr_broadcastable %1075, %1076 : tensor<2xindex>, tensor<2xindex> + %1078 = shape.assuming %1077 -> (tensor) { + %19688 = shape.broadcast %1075, %1076 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %1068, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %1074, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %1079 = shape.shape_of %1078 : tensor -> tensor<2xindex> + %1080 = stablehlo.dynamic_broadcast_in_dim %1078, %1079, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %1081 = stablehlo.dynamic_broadcast_in_dim %213, %1079, dims = [] : (tensor, tensor<2xindex>) -> tensor + %1082 = stablehlo.multiply %1080, %1081 : tensor + %dim_369 = tensor.dim %1042, %c0 : tensor + %1083 = arith.index_cast %dim_369 : index to i64 + %dim_370 = tensor.dim %1078, %c0 : tensor + %1084 = arith.index_cast %dim_370 : index to i64 + %1085 = arith.maxsi %1083, %1084 : i64 + %1086 = arith.index_cast %1085 : i64 to index + %from_elements_371 = tensor.from_elements %1086, %c4096 
: tensor<2xindex> + %1087 = stablehlo.dynamic_broadcast_in_dim %1042, %from_elements_371, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_372 = tensor.dim %1087, %c0 : tensor + %1088 = arith.index_cast %dim_372 : index to i64 + %from_elements_373 = tensor.from_elements %1088, %c4096_i64 : tensor<2xi64> + %1089 = stablehlo.real_dynamic_slice %1082, %c_22, %from_elements_373, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_374 = tensor.from_elements %1088, %c4096_i64, %c1_i64 : tensor<3xi64> + %1090 = stablehlo.dynamic_reshape %1087, %from_elements_374 : (tensor, tensor<3xi64>) -> tensor + %1091 = stablehlo.dynamic_iota %from_elements_374, dim = 1 : (tensor<3xi64>) -> tensor + %1092 = stablehlo.concatenate %1090, %1091, dim = 2 : (tensor, tensor) -> tensor + %1093 = "stablehlo.scatter"(%1030, %1092, %1089) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %1094 = stablehlo.slice %776 [5:6, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %1095 = stablehlo.reshape %1094 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %1096 = stablehlo.custom_call @byteir.non_zero(%1095) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_375 = tensor.dim %1096, %c0 : tensor + %1097 = arith.index_cast %dim_375 : index to i64 + %from_elements_376 = tensor.from_elements %1097, %c1_i64 : tensor<2xi64> + %1098 = stablehlo.real_dynamic_slice %1096, %c_22, %from_elements_376, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_377 = tensor.dim %1098, %c0 : tensor + %1099 = arith.index_cast %dim_377 : index to i64 + %from_elements_378 = tensor.from_elements %1099 : tensor<1xi64> + %1100 = stablehlo.dynamic_reshape %1098, %from_elements_378 : (tensor, tensor<1xi64>) -> tensor + %from_elements_379 = tensor.from_elements %1097, %c2_i64 : tensor<2xi64> + %1101 = stablehlo.real_dynamic_slice %1096, %c_24, %from_elements_379, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_380 = tensor.dim %1101, %c0 : tensor + %1102 = arith.index_cast %dim_380 : index to i64 + %from_elements_381 = tensor.from_elements %1102 : tensor<1xi64> + %1103 = stablehlo.dynamic_reshape %1101, %from_elements_381 : (tensor, tensor<1xi64>) -> tensor + %dim_382 = tensor.dim %1103, %c0 : tensor + %1104 = arith.index_cast %dim_382 : index to i64 + %from_elements_383 = tensor.from_elements %1104, %c1_i64 : tensor<2xi64> + %1105 = stablehlo.dynamic_reshape %1103, %from_elements_383 : (tensor, tensor<2xi64>) -> tensor + %dim_384 = tensor.dim %1105, %c0 : tensor + %1106 = arith.index_cast %dim_384 : index to i64 + %from_elements_385 = tensor.from_elements %c1_i64, %1106, %c4096_i64 : tensor<3xi64> + %1107 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_385, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_386 = tensor.dim %1107, %c1 : tensor<1x?x4096xi64> + %1108 = arith.index_cast %dim_386 : index to i64 + %from_elements_387 = tensor.from_elements %c1_i64, %1108, %c4096_i64, %c1_i64 : tensor<4xi64> + %1109 = stablehlo.dynamic_reshape %1107, %from_elements_387 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %1110 = stablehlo.dynamic_broadcast_in_dim %1105, %from_elements_385, dims = [1, 2] : (tensor, tensor<3xi64>) -> 
tensor<1x?x4096xi64> + %dim_388 = tensor.dim %1110, %c1 : tensor<1x?x4096xi64> + %1111 = arith.index_cast %dim_388 : index to i64 + %from_elements_389 = tensor.from_elements %c1_i64, %1111, %c4096_i64, %c1_i64 : tensor<4xi64> + %1112 = stablehlo.dynamic_reshape %1110, %from_elements_389 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %1113 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_385, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_390 = tensor.dim %1113, %c1 : tensor<1x?x4096xi64> + %1114 = arith.index_cast %dim_390 : index to i64 + %from_elements_391 = tensor.from_elements %c1_i64, %1114, %c4096_i64, %c1_i64 : tensor<4xi64> + %1115 = stablehlo.dynamic_reshape %1113, %from_elements_391 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %1116 = stablehlo.concatenate %1109, %1112, %1115, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %1117 = "stablehlo.gather"(%787, %1116) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %1118 = shape.shape_of %1117 : tensor<1x?x4096xf32> -> tensor<3xindex> + %1119 = shape.num_elements %1118 : tensor<3xindex> -> index + %1120 = stablehlo.compute_reshape_shape %1119, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %1121 = stablehlo.dynamic_reshape %1117, %1120 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %1122 = stablehlo.dot %1121, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %1123 = stablehlo.logistic %1122 : tensor + %1124 = shape.shape_of %1123 : tensor -> tensor<2xindex> + %1125 = shape.shape_of %1122 : tensor -> tensor<2xindex> + %1126 = shape.cstr_broadcastable %1124, %1125 : tensor<2xindex>, tensor<2xindex> + %1127 = shape.assuming %1126 -> (tensor) { + %19688 = shape.broadcast %1124, %1125 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %1123, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %1122, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %1128 = shape.shape_of %1127 : tensor -> tensor<2xindex> + %1129 = shape.cstr_broadcastable %1128, %1125 : tensor<2xindex>, tensor<2xindex> + %1130 = shape.assuming %1129 -> (tensor) { + %19688 = shape.broadcast %1128, %1125 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %1127, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %1122, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %1131 = stablehlo.dot %1130, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_392 = tensor.dim %1103, %c0 : tensor + %1132 = arith.index_cast %dim_392 : index to i64 + %from_elements_393 = tensor.from_elements %1132, %c1_i64 : tensor<2xi64> + %1133 = stablehlo.dynamic_reshape %1103, %from_elements_393 : (tensor, tensor<2xi64>) -> tensor + %dim_394 = tensor.dim %1100, %c0 : tensor + %1134 = arith.index_cast %dim_394 : index to i64 + %from_elements_395 = tensor.from_elements %1134, %c1_i64 : tensor<2xi64> + %1135 = stablehlo.dynamic_reshape %1100, %from_elements_395 : (tensor, tensor<2xi64>) -> tensor + %1136 = 
stablehlo.concatenate %1133, %1135, dim = 1 : (tensor, tensor) -> tensor + %1137 = "stablehlo.gather"(%816, %1136) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %1138 = shape.shape_of %1131 : tensor -> tensor<2xindex> + %1139 = shape.shape_of %1137 : tensor -> tensor<2xindex> + %1140 = shape.cstr_broadcastable %1138, %1139 : tensor<2xindex>, tensor<2xindex> + %1141 = shape.assuming %1140 -> (tensor) { + %19688 = shape.broadcast %1138, %1139 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %1131, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %1137, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %1142 = shape.shape_of %1141 : tensor -> tensor<2xindex> + %1143 = stablehlo.dynamic_broadcast_in_dim %1141, %1142, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %1144 = stablehlo.dynamic_broadcast_in_dim %213, %1142, dims = [] : (tensor, tensor<2xindex>) -> tensor + %1145 = stablehlo.multiply %1143, %1144 : tensor + %dim_396 = tensor.dim %1105, %c0 : tensor + %1146 = arith.index_cast %dim_396 : index to i64 + %dim_397 = tensor.dim %1141, %c0 : tensor + %1147 = arith.index_cast %dim_397 : index to i64 + %1148 = arith.maxsi %1146, %1147 : i64 + %1149 = arith.index_cast %1148 : i64 to index + %from_elements_398 = tensor.from_elements %1149, %c4096 : tensor<2xindex> + %1150 = stablehlo.dynamic_broadcast_in_dim %1105, %from_elements_398, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_399 = tensor.dim %1150, %c0 : tensor + %1151 = arith.index_cast %dim_399 : index to i64 + %from_elements_400 = tensor.from_elements %1151, %c4096_i64 : tensor<2xi64> + %1152 = stablehlo.real_dynamic_slice %1145, %c_22, %from_elements_400, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_401 = tensor.from_elements %1151, %c4096_i64, %c1_i64 : tensor<3xi64> + %1153 = stablehlo.dynamic_reshape %1150, %from_elements_401 : (tensor, tensor<3xi64>) -> tensor + %1154 = stablehlo.dynamic_iota %from_elements_401, dim = 1 : (tensor<3xi64>) -> tensor + %1155 = stablehlo.concatenate %1153, %1154, dim = 2 : (tensor, tensor) -> tensor + %1156 = "stablehlo.scatter"(%1093, %1155, %1152) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %1157 = stablehlo.slice %776 [6:7, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %1158 = stablehlo.reshape %1157 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %1159 = stablehlo.custom_call @byteir.non_zero(%1158) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_402 = tensor.dim %1159, %c0 : tensor + %1160 = arith.index_cast %dim_402 : index to i64 + %from_elements_403 = tensor.from_elements %1160, %c1_i64 : tensor<2xi64> + %1161 = stablehlo.real_dynamic_slice %1159, %c_22, %from_elements_403, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_404 = tensor.dim %1161, %c0 : tensor + %1162 = arith.index_cast %dim_404 : index to i64 + %from_elements_405 = tensor.from_elements %1162 : tensor<1xi64> + %1163 = stablehlo.dynamic_reshape %1161, %from_elements_405 : 
(tensor, tensor<1xi64>) -> tensor + %from_elements_406 = tensor.from_elements %1160, %c2_i64 : tensor<2xi64> + %1164 = stablehlo.real_dynamic_slice %1159, %c_24, %from_elements_406, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_407 = tensor.dim %1164, %c0 : tensor + %1165 = arith.index_cast %dim_407 : index to i64 + %from_elements_408 = tensor.from_elements %1165 : tensor<1xi64> + %1166 = stablehlo.dynamic_reshape %1164, %from_elements_408 : (tensor, tensor<1xi64>) -> tensor + %dim_409 = tensor.dim %1166, %c0 : tensor + %1167 = arith.index_cast %dim_409 : index to i64 + %from_elements_410 = tensor.from_elements %1167, %c1_i64 : tensor<2xi64> + %1168 = stablehlo.dynamic_reshape %1166, %from_elements_410 : (tensor, tensor<2xi64>) -> tensor + %dim_411 = tensor.dim %1168, %c0 : tensor + %1169 = arith.index_cast %dim_411 : index to i64 + %from_elements_412 = tensor.from_elements %c1_i64, %1169, %c4096_i64 : tensor<3xi64> + %1170 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_412, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_413 = tensor.dim %1170, %c1 : tensor<1x?x4096xi64> + %1171 = arith.index_cast %dim_413 : index to i64 + %from_elements_414 = tensor.from_elements %c1_i64, %1171, %c4096_i64, %c1_i64 : tensor<4xi64> + %1172 = stablehlo.dynamic_reshape %1170, %from_elements_414 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %1173 = stablehlo.dynamic_broadcast_in_dim %1168, %from_elements_412, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_415 = tensor.dim %1173, %c1 : tensor<1x?x4096xi64> + %1174 = arith.index_cast %dim_415 : index to i64 + %from_elements_416 = tensor.from_elements %c1_i64, %1174, %c4096_i64, %c1_i64 : tensor<4xi64> + %1175 = stablehlo.dynamic_reshape %1173, %from_elements_416 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %1176 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_412, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_417 = tensor.dim %1176, %c1 : tensor<1x?x4096xi64> + %1177 = arith.index_cast %dim_417 : index to i64 + %from_elements_418 = tensor.from_elements %c1_i64, %1177, %c4096_i64, %c1_i64 : tensor<4xi64> + %1178 = stablehlo.dynamic_reshape %1176, %from_elements_418 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %1179 = stablehlo.concatenate %1172, %1175, %1178, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %1180 = "stablehlo.gather"(%787, %1179) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %1181 = shape.shape_of %1180 : tensor<1x?x4096xf32> -> tensor<3xindex> + %1182 = shape.num_elements %1181 : tensor<3xindex> -> index + %1183 = stablehlo.compute_reshape_shape %1182, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %1184 = stablehlo.dynamic_reshape %1180, %1183 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %1185 = stablehlo.dot %1184, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %1186 = stablehlo.logistic %1185 : tensor + %1187 = shape.shape_of %1186 : tensor -> tensor<2xindex> + %1188 = shape.shape_of %1185 : tensor -> tensor<2xindex> + %1189 = shape.cstr_broadcastable %1187, %1188 : tensor<2xindex>, tensor<2xindex> + %1190 = shape.assuming %1189 -> (tensor) { + %19688 = shape.broadcast %1187, %1188 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + 
%19689 = stablehlo.dynamic_broadcast_in_dim %1186, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %1185, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %1191 = shape.shape_of %1190 : tensor -> tensor<2xindex> + %1192 = shape.cstr_broadcastable %1191, %1188 : tensor<2xindex>, tensor<2xindex> + %1193 = shape.assuming %1192 -> (tensor) { + %19688 = shape.broadcast %1191, %1188 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %1190, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %1185, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %1194 = stablehlo.dot %1193, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_419 = tensor.dim %1166, %c0 : tensor + %1195 = arith.index_cast %dim_419 : index to i64 + %from_elements_420 = tensor.from_elements %1195, %c1_i64 : tensor<2xi64> + %1196 = stablehlo.dynamic_reshape %1166, %from_elements_420 : (tensor, tensor<2xi64>) -> tensor + %dim_421 = tensor.dim %1163, %c0 : tensor + %1197 = arith.index_cast %dim_421 : index to i64 + %from_elements_422 = tensor.from_elements %1197, %c1_i64 : tensor<2xi64> + %1198 = stablehlo.dynamic_reshape %1163, %from_elements_422 : (tensor, tensor<2xi64>) -> tensor + %1199 = stablehlo.concatenate %1196, %1198, dim = 1 : (tensor, tensor) -> tensor + %1200 = "stablehlo.gather"(%816, %1199) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %1201 = shape.shape_of %1194 : tensor -> tensor<2xindex> + %1202 = shape.shape_of %1200 : tensor -> tensor<2xindex> + %1203 = shape.cstr_broadcastable %1201, %1202 : tensor<2xindex>, tensor<2xindex> + %1204 = shape.assuming %1203 -> (tensor) { + %19688 = shape.broadcast %1201, %1202 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %1194, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %1200, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %1205 = shape.shape_of %1204 : tensor -> tensor<2xindex> + %1206 = stablehlo.dynamic_broadcast_in_dim %1204, %1205, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %1207 = stablehlo.dynamic_broadcast_in_dim %213, %1205, dims = [] : (tensor, tensor<2xindex>) -> tensor + %1208 = stablehlo.multiply %1206, %1207 : tensor + %dim_423 = tensor.dim %1168, %c0 : tensor + %1209 = arith.index_cast %dim_423 : index to i64 + %dim_424 = tensor.dim %1204, %c0 : tensor + %1210 = arith.index_cast %dim_424 : index to i64 + %1211 = arith.maxsi %1209, %1210 : i64 + %1212 = arith.index_cast %1211 : i64 to index + %from_elements_425 = tensor.from_elements %1212, %c4096 : tensor<2xindex> + %1213 = stablehlo.dynamic_broadcast_in_dim %1168, %from_elements_425, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_426 = tensor.dim %1213, %c0 : tensor + %1214 = arith.index_cast %dim_426 : index to i64 + %from_elements_427 = tensor.from_elements %1214, %c4096_i64 : tensor<2xi64> + %1215 = stablehlo.real_dynamic_slice %1208, %c_22, %from_elements_427, %c_21 : (tensor, tensor<2xi64>, 
tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_428 = tensor.from_elements %1214, %c4096_i64, %c1_i64 : tensor<3xi64> + %1216 = stablehlo.dynamic_reshape %1213, %from_elements_428 : (tensor, tensor<3xi64>) -> tensor + %1217 = stablehlo.dynamic_iota %from_elements_428, dim = 1 : (tensor<3xi64>) -> tensor + %1218 = stablehlo.concatenate %1216, %1217, dim = 2 : (tensor, tensor) -> tensor + %1219 = "stablehlo.scatter"(%1156, %1218, %1215) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %1220 = stablehlo.slice %776 [7:8, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %1221 = stablehlo.reshape %1220 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %1222 = stablehlo.custom_call @byteir.non_zero(%1221) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_429 = tensor.dim %1222, %c0 : tensor + %1223 = arith.index_cast %dim_429 : index to i64 + %from_elements_430 = tensor.from_elements %1223, %c1_i64 : tensor<2xi64> + %1224 = stablehlo.real_dynamic_slice %1222, %c_22, %from_elements_430, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_431 = tensor.dim %1224, %c0 : tensor + %1225 = arith.index_cast %dim_431 : index to i64 + %from_elements_432 = tensor.from_elements %1225 : tensor<1xi64> + %1226 = stablehlo.dynamic_reshape %1224, %from_elements_432 : (tensor, tensor<1xi64>) -> tensor + %from_elements_433 = tensor.from_elements %1223, %c2_i64 : tensor<2xi64> + %1227 = stablehlo.real_dynamic_slice %1222, %c_24, %from_elements_433, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_434 = tensor.dim %1227, %c0 : tensor + %1228 = arith.index_cast %dim_434 : index to i64 + %from_elements_435 = tensor.from_elements %1228 : tensor<1xi64> + %1229 = stablehlo.dynamic_reshape %1227, %from_elements_435 : (tensor, tensor<1xi64>) -> tensor + %dim_436 = tensor.dim %1229, %c0 : tensor + %1230 = arith.index_cast %dim_436 : index to i64 + %from_elements_437 = tensor.from_elements %1230, %c1_i64 : tensor<2xi64> + %1231 = stablehlo.dynamic_reshape %1229, %from_elements_437 : (tensor, tensor<2xi64>) -> tensor + %dim_438 = tensor.dim %1231, %c0 : tensor + %1232 = arith.index_cast %dim_438 : index to i64 + %from_elements_439 = tensor.from_elements %c1_i64, %1232, %c4096_i64 : tensor<3xi64> + %1233 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_439, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_440 = tensor.dim %1233, %c1 : tensor<1x?x4096xi64> + %1234 = arith.index_cast %dim_440 : index to i64 + %from_elements_441 = tensor.from_elements %c1_i64, %1234, %c4096_i64, %c1_i64 : tensor<4xi64> + %1235 = stablehlo.dynamic_reshape %1233, %from_elements_441 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %1236 = stablehlo.dynamic_broadcast_in_dim %1231, %from_elements_439, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_442 = tensor.dim %1236, %c1 : tensor<1x?x4096xi64> + %1237 = arith.index_cast %dim_442 : index to i64 + %from_elements_443 = tensor.from_elements %c1_i64, %1237, %c4096_i64, %c1_i64 : tensor<4xi64> + %1238 = stablehlo.dynamic_reshape %1236, %from_elements_443 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %1239 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_439, dims = [2] 
: (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_444 = tensor.dim %1239, %c1 : tensor<1x?x4096xi64> + %1240 = arith.index_cast %dim_444 : index to i64 + %from_elements_445 = tensor.from_elements %c1_i64, %1240, %c4096_i64, %c1_i64 : tensor<4xi64> + %1241 = stablehlo.dynamic_reshape %1239, %from_elements_445 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %1242 = stablehlo.concatenate %1235, %1238, %1241, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %1243 = "stablehlo.gather"(%787, %1242) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %1244 = shape.shape_of %1243 : tensor<1x?x4096xf32> -> tensor<3xindex> + %1245 = shape.num_elements %1244 : tensor<3xindex> -> index + %1246 = stablehlo.compute_reshape_shape %1245, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %1247 = stablehlo.dynamic_reshape %1243, %1246 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %1248 = stablehlo.dot %1247, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %1249 = stablehlo.logistic %1248 : tensor + %1250 = shape.shape_of %1249 : tensor -> tensor<2xindex> + %1251 = shape.shape_of %1248 : tensor -> tensor<2xindex> + %1252 = shape.cstr_broadcastable %1250, %1251 : tensor<2xindex>, tensor<2xindex> + %1253 = shape.assuming %1252 -> (tensor) { + %19688 = shape.broadcast %1250, %1251 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %1249, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %1248, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %1254 = shape.shape_of %1253 : tensor -> tensor<2xindex> + %1255 = shape.cstr_broadcastable %1254, %1251 : tensor<2xindex>, tensor<2xindex> + %1256 = shape.assuming %1255 -> (tensor) { + %19688 = shape.broadcast %1254, %1251 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %1253, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %1248, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %1257 = stablehlo.dot %1256, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_446 = tensor.dim %1229, %c0 : tensor + %1258 = arith.index_cast %dim_446 : index to i64 + %from_elements_447 = tensor.from_elements %1258, %c1_i64 : tensor<2xi64> + %1259 = stablehlo.dynamic_reshape %1229, %from_elements_447 : (tensor, tensor<2xi64>) -> tensor + %dim_448 = tensor.dim %1226, %c0 : tensor + %1260 = arith.index_cast %dim_448 : index to i64 + %from_elements_449 = tensor.from_elements %1260, %c1_i64 : tensor<2xi64> + %1261 = stablehlo.dynamic_reshape %1226, %from_elements_449 : (tensor, tensor<2xi64>) -> tensor + %1262 = stablehlo.concatenate %1259, %1261, dim = 1 : (tensor, tensor) -> tensor + %1263 = "stablehlo.gather"(%816, %1262) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %1264 = shape.shape_of %1257 : tensor -> tensor<2xindex> + %1265 = shape.shape_of %1263 : tensor -> tensor<2xindex> + %1266 = shape.cstr_broadcastable %1264, %1265 : tensor<2xindex>, 
tensor<2xindex> + %1267 = shape.assuming %1266 -> (tensor) { + %19688 = shape.broadcast %1264, %1265 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %1257, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %1263, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %1268 = shape.shape_of %1267 : tensor -> tensor<2xindex> + %1269 = stablehlo.dynamic_broadcast_in_dim %1267, %1268, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %1270 = stablehlo.dynamic_broadcast_in_dim %213, %1268, dims = [] : (tensor, tensor<2xindex>) -> tensor + %1271 = stablehlo.multiply %1269, %1270 : tensor + %dim_450 = tensor.dim %1231, %c0 : tensor + %1272 = arith.index_cast %dim_450 : index to i64 + %dim_451 = tensor.dim %1267, %c0 : tensor + %1273 = arith.index_cast %dim_451 : index to i64 + %1274 = arith.maxsi %1272, %1273 : i64 + %1275 = arith.index_cast %1274 : i64 to index + %from_elements_452 = tensor.from_elements %1275, %c4096 : tensor<2xindex> + %1276 = stablehlo.dynamic_broadcast_in_dim %1231, %from_elements_452, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_453 = tensor.dim %1276, %c0 : tensor + %1277 = arith.index_cast %dim_453 : index to i64 + %from_elements_454 = tensor.from_elements %1277, %c4096_i64 : tensor<2xi64> + %1278 = stablehlo.real_dynamic_slice %1271, %c_22, %from_elements_454, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_455 = tensor.from_elements %1277, %c4096_i64, %c1_i64 : tensor<3xi64> + %1279 = stablehlo.dynamic_reshape %1276, %from_elements_455 : (tensor, tensor<3xi64>) -> tensor + %1280 = stablehlo.dynamic_iota %from_elements_455, dim = 1 : (tensor<3xi64>) -> tensor + %1281 = stablehlo.concatenate %1279, %1280, dim = 2 : (tensor, tensor) -> tensor + %1282 = "stablehlo.scatter"(%1219, %1281, %1278) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %1283 = stablehlo.reshape %1282 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %1284 = stablehlo.add %749, %1283 : tensor<3x1x4096xf32> + %1285 = stablehlo.broadcast_in_dim %1284, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %1286 = stablehlo.power %1285, %15 : tensor<3x1x4096xf32> + %1287 = stablehlo.reduce(%1286 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %1288 = stablehlo.reshape %1287 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %1289 = stablehlo.broadcast_in_dim %1288, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %1290 = stablehlo.divide %1289, %21 : tensor<3x1x1xf32> + %1291 = stablehlo.broadcast_in_dim %1290, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %1292 = stablehlo.add %1291, %25 : tensor<3x1x1xf32> + %1293 = stablehlo.rsqrt %1292 : tensor<3x1x1xf32> + %1294 = stablehlo.broadcast_in_dim %1293, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %1295 = stablehlo.multiply %1285, %1294 : tensor<3x1x4096xf32> + %1296 = stablehlo.broadcast_in_dim %1295, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %1297 = stablehlo.multiply %1296, %31 : 
+  // ... the generated StableHLO continues here; for each of the 32 Mixtral decoder layers the dump lowers:
+  //   - the attention projections (4096x4096 and 4096x1024 dots) with the rotate-half rotary position embedding
+  //     applied to the query/key states using rows gathered from the precomputed 131072x128 cos/sin tables
+  //     (slice [0:64]/[64:128], negate, concatenate, multiply, add),
+  //   - concatenation of the new key/value states onto the 3x8x7x128 past_key_values inputs and grouped-query
+  //     repetition of the 8 key/value heads to 32 query heads (reshape, broadcast_in_dim, reshape),
+  //   - scaled dot-product attention through the byteir.softmax custom call, the output projection, the residual
+  //     add, and RMSNorm (power, reduce, rsqrt, multiply by the norm weight),
+  //   - MoE routing: a 4096x8 gate matmul, byteir.softmax, byteir.top_k (k = 2, sorted), and renormalization of
+  //     the two routing weights per token,
+  //   - an unrolled block for experts 0..7: byteir.non_zero over one slice of the expert mask, dynamic gather of
+  //     the routed tokens, the SiLU-gated expert MLP (4096x14336 and 14336x4096 dots), scaling by the gathered
+  //     routing weight, and an add-combining stablehlo.scatter into the 3x4096 layer output.
+  //   The two sketches below restate the rotary/grouped-query and expert-dispatch patterns in PyTorch.
%19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %1867 = shape.shape_of %1866 : tensor -> tensor<2xindex> + %1868 = shape.cstr_broadcastable %1867, %1864 : tensor<2xindex>, tensor<2xindex> + %1869 = shape.assuming %1868 -> (tensor) { + %19688 = shape.broadcast %1867, %1864 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %1866, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %1861, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %1870 = stablehlo.dot %1869, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_662 = tensor.dim %1842, %c0 : tensor + %1871 = arith.index_cast %dim_662 : index to i64 + %from_elements_663 = tensor.from_elements %1871, %c1_i64 : tensor<2xi64> + %1872 = stablehlo.dynamic_reshape %1842, %from_elements_663 : (tensor, tensor<2xi64>) -> tensor + %dim_664 = tensor.dim %1839, %c0 : tensor + %1873 = arith.index_cast %dim_664 : index to i64 + %from_elements_665 = tensor.from_elements %1873, %c1_i64 : tensor<2xi64> + %1874 = stablehlo.dynamic_reshape %1839, %from_elements_665 : (tensor, tensor<2xi64>) -> tensor + %1875 = stablehlo.concatenate %1872, %1874, dim = 1 : (tensor, tensor) -> tensor + %1876 = "stablehlo.gather"(%1429, %1875) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %1877 = shape.shape_of %1870 : tensor -> tensor<2xindex> + %1878 = shape.shape_of %1876 : tensor -> tensor<2xindex> + %1879 = shape.cstr_broadcastable %1877, %1878 : tensor<2xindex>, tensor<2xindex> + %1880 = shape.assuming %1879 -> (tensor) { + %19688 = shape.broadcast %1877, %1878 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %1870, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %1876, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %1881 = shape.shape_of %1880 : tensor -> tensor<2xindex> + %1882 = stablehlo.dynamic_broadcast_in_dim %1880, %1881, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %1883 = stablehlo.dynamic_broadcast_in_dim %213, %1881, dims = [] : (tensor, tensor<2xindex>) -> tensor + %1884 = stablehlo.multiply %1882, %1883 : tensor + %dim_666 = tensor.dim %1844, %c0 : tensor + %1885 = arith.index_cast %dim_666 : index to i64 + %dim_667 = tensor.dim %1880, %c0 : tensor + %1886 = arith.index_cast %dim_667 : index to i64 + %1887 = arith.maxsi %1885, %1886 : i64 + %1888 = arith.index_cast %1887 : i64 to index + %from_elements_668 = tensor.from_elements %1888, %c4096 : tensor<2xindex> + %1889 = stablehlo.dynamic_broadcast_in_dim %1844, %from_elements_668, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_669 = tensor.dim %1889, %c0 : tensor + %1890 = arith.index_cast %dim_669 : index to i64 + %from_elements_670 = tensor.from_elements %1890, %c4096_i64 : tensor<2xi64> + %1891 = stablehlo.real_dynamic_slice %1884, %c_22, %from_elements_670, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_671 = tensor.from_elements %1890, %c4096_i64, %c1_i64 : tensor<3xi64> + %1892 = stablehlo.dynamic_reshape %1889, %from_elements_671 : (tensor, tensor<3xi64>) -> tensor + %1893 = 
stablehlo.dynamic_iota %from_elements_671, dim = 1 : (tensor<3xi64>) -> tensor + %1894 = stablehlo.concatenate %1892, %1893, dim = 2 : (tensor, tensor) -> tensor + %1895 = "stablehlo.scatter"(%1832, %1894, %1891) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %1896 = stablehlo.reshape %1895 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %1897 = stablehlo.add %1362, %1896 : tensor<3x1x4096xf32> + %1898 = stablehlo.broadcast_in_dim %1897, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %1899 = stablehlo.power %1898, %15 : tensor<3x1x4096xf32> + %1900 = stablehlo.reduce(%1899 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %1901 = stablehlo.reshape %1900 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %1902 = stablehlo.broadcast_in_dim %1901, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %1903 = stablehlo.divide %1902, %21 : tensor<3x1x1xf32> + %1904 = stablehlo.broadcast_in_dim %1903, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %1905 = stablehlo.add %1904, %25 : tensor<3x1x1xf32> + %1906 = stablehlo.rsqrt %1905 : tensor<3x1x1xf32> + %1907 = stablehlo.broadcast_in_dim %1906, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %1908 = stablehlo.multiply %1898, %1907 : tensor<3x1x4096xf32> + %1909 = stablehlo.broadcast_in_dim %1908, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %1910 = stablehlo.multiply %1909, %31 : tensor<3x1x4096xf32> + %1911 = stablehlo.reshape %1910 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %1912 = stablehlo.dot %1911, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %1913 = stablehlo.reshape %1912 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %1914 = stablehlo.dot %1911, %37 : (tensor<3x4096xf32>, tensor<4096x1024xf32>) -> tensor<3x1024xf32> + %1915 = stablehlo.reshape %1914 : (tensor<3x1024xf32>) -> tensor<3x1x1024xf32> + %1916 = stablehlo.reshape %1913 : (tensor<3x1x4096xf32>) -> tensor<3x1x32x128xf32> + %1917 = stablehlo.transpose %1916, dims = [0, 2, 1, 3] : (tensor<3x1x32x128xf32>) -> tensor<3x32x1x128xf32> + %1918 = stablehlo.reshape %1915 : (tensor<3x1x1024xf32>) -> tensor<3x1x8x128xf32> + %1919 = stablehlo.transpose %1918, dims = [0, 2, 1, 3] : (tensor<3x1x8x128xf32>) -> tensor<3x8x1x128xf32> + %1920 = stablehlo.slice %arg6 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %1921 = stablehlo.slice %arg7 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %1922 = "stablehlo.gather"(%1920, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %1923 = stablehlo.reshape %1922 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %1924 = "stablehlo.gather"(%1921, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %1925 = stablehlo.reshape %1924 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %1926 = stablehlo.broadcast_in_dim %1917, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %1927 = stablehlo.broadcast_in_dim %1923, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> 
tensor<3x32x1x128xf32> + %1928 = stablehlo.multiply %1926, %1927 : tensor<3x32x1x128xf32> + %1929 = stablehlo.slice %1917 [0:3, 0:32, 0:1, 0:64] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %1930 = stablehlo.slice %1917 [0:3, 0:32, 0:1, 64:128] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %1931 = stablehlo.negate %1930 : tensor<3x32x1x64xf32> + %1932 = stablehlo.concatenate %1931, %1929, dim = 3 : (tensor<3x32x1x64xf32>, tensor<3x32x1x64xf32>) -> tensor<3x32x1x128xf32> + %1933 = stablehlo.broadcast_in_dim %1932, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %1934 = stablehlo.broadcast_in_dim %1925, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %1935 = stablehlo.multiply %1933, %1934 : tensor<3x32x1x128xf32> + %1936 = stablehlo.add %1928, %1935 : tensor<3x32x1x128xf32> + %1937 = stablehlo.broadcast_in_dim %1919, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %1938 = stablehlo.broadcast_in_dim %1923, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %1939 = stablehlo.multiply %1937, %1938 : tensor<3x8x1x128xf32> + %1940 = stablehlo.slice %1919 [0:3, 0:8, 0:1, 0:64] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %1941 = stablehlo.slice %1919 [0:3, 0:8, 0:1, 64:128] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %1942 = stablehlo.negate %1941 : tensor<3x8x1x64xf32> + %1943 = stablehlo.concatenate %1942, %1940, dim = 3 : (tensor<3x8x1x64xf32>, tensor<3x8x1x64xf32>) -> tensor<3x8x1x128xf32> + %1944 = stablehlo.broadcast_in_dim %1943, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %1945 = stablehlo.broadcast_in_dim %1925, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %1946 = stablehlo.multiply %1944, %1945 : tensor<3x8x1x128xf32> + %1947 = stablehlo.add %1939, %1946 : tensor<3x8x1x128xf32> + %1948 = stablehlo.concatenate %arg71, %1947, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %1949 = stablehlo.concatenate %arg72, %1919, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %1950 = stablehlo.reshape %1948 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %1951 = stablehlo.broadcast_in_dim %1950, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %1952 = stablehlo.reshape %1951 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %1953 = stablehlo.reshape %1949 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %1954 = stablehlo.broadcast_in_dim %1953, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %1955 = stablehlo.reshape %1954 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %1956 = stablehlo.transpose %1952, dims = [0, 1, 3, 2] : (tensor<3x32x8x128xf32>) -> tensor<3x32x128x8xf32> + %1957 = stablehlo.reshape %1936 : (tensor<3x32x1x128xf32>) -> tensor<96x1x128xf32> + %1958 = stablehlo.reshape %1956 : (tensor<3x32x128x8xf32>) -> tensor<96x128x8xf32> + %1959 = stablehlo.broadcast_in_dim %1958, dims = [0, 1, 2] : (tensor<96x128x8xf32>) -> tensor<96x128x8xf32> + %1960 = stablehlo.dot_general %1957, %1959, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x128xf32>, tensor<96x128x8xf32>) -> tensor<96x1x8xf32> + %1961 = stablehlo.reshape %1960 : (tensor<96x1x8xf32>) -> tensor<3x32x1x8xf32> + %1962 = stablehlo.broadcast_in_dim %1961, dims = [0, 1, 2, 3] : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %1963 = stablehlo.divide %1962, %89 : 
tensor<3x32x1x8xf32> + %1964 = stablehlo.custom_call @byteir.softmax(%1963) {byteir_attrs = {axis = 3 : i64}} : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %1965 = stablehlo.reshape %1964 : (tensor<3x32x1x8xf32>) -> tensor<96x1x8xf32> + %1966 = stablehlo.reshape %1955 : (tensor<3x32x8x128xf32>) -> tensor<96x8x128xf32> + %1967 = stablehlo.broadcast_in_dim %1966, dims = [0, 1, 2] : (tensor<96x8x128xf32>) -> tensor<96x8x128xf32> + %1968 = stablehlo.dot_general %1965, %1967, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x8xf32>, tensor<96x8x128xf32>) -> tensor<96x1x128xf32> + %1969 = stablehlo.reshape %1968 : (tensor<96x1x128xf32>) -> tensor<3x32x1x128xf32> + %1970 = stablehlo.transpose %1969, dims = [0, 2, 1, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x1x32x128xf32> + %1971 = stablehlo.reshape %1970 : (tensor<3x1x32x128xf32>) -> tensor<3x1x4096xf32> + %1972 = stablehlo.reshape %1971 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %1973 = stablehlo.dot %1972, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %1974 = stablehlo.reshape %1973 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %1975 = stablehlo.add %1897, %1974 : tensor<3x1x4096xf32> + %1976 = stablehlo.broadcast_in_dim %1975, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %1977 = stablehlo.power %1976, %15 : tensor<3x1x4096xf32> + %1978 = stablehlo.reduce(%1977 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %1979 = stablehlo.reshape %1978 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %1980 = stablehlo.broadcast_in_dim %1979, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %1981 = stablehlo.divide %1980, %21 : tensor<3x1x1xf32> + %1982 = stablehlo.broadcast_in_dim %1981, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %1983 = stablehlo.add %1982, %25 : tensor<3x1x1xf32> + %1984 = stablehlo.rsqrt %1983 : tensor<3x1x1xf32> + %1985 = stablehlo.broadcast_in_dim %1984, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %1986 = stablehlo.multiply %1976, %1985 : tensor<3x1x4096xf32> + %1987 = stablehlo.broadcast_in_dim %1986, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %1988 = stablehlo.multiply %1987, %31 : tensor<3x1x4096xf32> + %1989 = stablehlo.reshape %1988 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %1990 = stablehlo.dot %1989, %117 : (tensor<3x4096xf32>, tensor<4096x8xf32>) -> tensor<3x8xf32> + %1991 = stablehlo.custom_call @byteir.softmax(%1990) {byteir_attrs = {axis = 1 : i64}} : (tensor<3x8xf32>) -> tensor<3x8xf32> + %1992:2 = stablehlo.custom_call @byteir.top_k(%1991) {byteir_attrs = {axis = [1], k = 2 : i64, sorted = true}} : (tensor<3x8xf32>) -> (tensor<3x2xf32>, tensor<3x2xi64>) + %1993 = stablehlo.reduce(%1992#0 init: %cst_20) applies stablehlo.add across dimensions = [1] : (tensor<3x2xf32>, tensor) -> tensor<3xf32> + %1994 = stablehlo.reshape %1993 : (tensor<3xf32>) -> tensor<3x1xf32> + %1995 = stablehlo.broadcast_in_dim %1992#0, dims = [0, 1] : (tensor<3x2xf32>) -> tensor<3x2xf32> + %1996 = stablehlo.broadcast_in_dim %1994, dims = [0, 1] : (tensor<3x1xf32>) -> tensor<3x2xf32> + %1997 = stablehlo.divide %1995, %1996 : tensor<3x2xf32> + %1998 = stablehlo.reshape %1992#1 : (tensor<3x2xi64>) -> tensor<3x2x1xi64> + %1999 = stablehlo.broadcast_in_dim %1998, dims = [0, 1, 2] : (tensor<3x2x1xi64>) -> tensor<3x2x8xi64> + %2000 = stablehlo.compare EQ, %1999, %137, SIGNED : (tensor<3x2x8xi64>, tensor<3x2x8xi64>) -> 
tensor<3x2x8xi1> + %2001 = stablehlo.convert %2000 : (tensor<3x2x8xi1>) -> tensor<3x2x8xi64> + %2002 = stablehlo.transpose %2001, dims = [2, 1, 0] : (tensor<3x2x8xi64>) -> tensor<8x2x3xi64> + %2003 = stablehlo.slice %2002 [0:1, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %2004 = stablehlo.reshape %2003 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %2005 = stablehlo.custom_call @byteir.non_zero(%2004) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_672 = tensor.dim %2005, %c0 : tensor + %2006 = arith.index_cast %dim_672 : index to i64 + %from_elements_673 = tensor.from_elements %2006, %c1_i64 : tensor<2xi64> + %2007 = stablehlo.real_dynamic_slice %2005, %c_22, %from_elements_673, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_674 = tensor.dim %2007, %c0 : tensor + %2008 = arith.index_cast %dim_674 : index to i64 + %from_elements_675 = tensor.from_elements %2008 : tensor<1xi64> + %2009 = stablehlo.dynamic_reshape %2007, %from_elements_675 : (tensor, tensor<1xi64>) -> tensor + %from_elements_676 = tensor.from_elements %2006, %c2_i64 : tensor<2xi64> + %2010 = stablehlo.real_dynamic_slice %2005, %c_24, %from_elements_676, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_677 = tensor.dim %2010, %c0 : tensor + %2011 = arith.index_cast %dim_677 : index to i64 + %from_elements_678 = tensor.from_elements %2011 : tensor<1xi64> + %2012 = stablehlo.dynamic_reshape %2010, %from_elements_678 : (tensor, tensor<1xi64>) -> tensor + %2013 = stablehlo.reshape %1989 : (tensor<3x4096xf32>) -> tensor<1x3x4096xf32> + %dim_679 = tensor.dim %2012, %c0 : tensor + %2014 = arith.index_cast %dim_679 : index to i64 + %from_elements_680 = tensor.from_elements %2014, %c1_i64 : tensor<2xi64> + %2015 = stablehlo.dynamic_reshape %2012, %from_elements_680 : (tensor, tensor<2xi64>) -> tensor + %dim_681 = tensor.dim %2015, %c0 : tensor + %2016 = arith.index_cast %dim_681 : index to i64 + %from_elements_682 = tensor.from_elements %c1_i64, %2016, %c4096_i64 : tensor<3xi64> + %2017 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_682, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_683 = tensor.dim %2017, %c1 : tensor<1x?x4096xi64> + %2018 = arith.index_cast %dim_683 : index to i64 + %from_elements_684 = tensor.from_elements %c1_i64, %2018, %c4096_i64, %c1_i64 : tensor<4xi64> + %2019 = stablehlo.dynamic_reshape %2017, %from_elements_684 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %2020 = stablehlo.dynamic_broadcast_in_dim %2015, %from_elements_682, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_685 = tensor.dim %2020, %c1 : tensor<1x?x4096xi64> + %2021 = arith.index_cast %dim_685 : index to i64 + %from_elements_686 = tensor.from_elements %c1_i64, %2021, %c4096_i64, %c1_i64 : tensor<4xi64> + %2022 = stablehlo.dynamic_reshape %2020, %from_elements_686 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %2023 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_682, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_687 = tensor.dim %2023, %c1 : tensor<1x?x4096xi64> + %2024 = arith.index_cast %dim_687 : index to i64 + %from_elements_688 = tensor.from_elements %c1_i64, %2024, %c4096_i64, %c1_i64 : tensor<4xi64> + %2025 = stablehlo.dynamic_reshape %2023, %from_elements_688 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %2026 = stablehlo.concatenate %2019, %2022, %2025, dim = 3 : 
(tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %2027 = "stablehlo.gather"(%2013, %2026) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %2028 = shape.shape_of %2027 : tensor<1x?x4096xf32> -> tensor<3xindex> + %2029 = shape.num_elements %2028 : tensor<3xindex> -> index + %2030 = stablehlo.compute_reshape_shape %2029, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %2031 = stablehlo.dynamic_reshape %2027, %2030 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %2032 = stablehlo.dot %2031, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %2033 = stablehlo.logistic %2032 : tensor + %2034 = shape.shape_of %2033 : tensor -> tensor<2xindex> + %2035 = shape.shape_of %2032 : tensor -> tensor<2xindex> + %2036 = shape.cstr_broadcastable %2034, %2035 : tensor<2xindex>, tensor<2xindex> + %2037 = shape.assuming %2036 -> (tensor) { + %19688 = shape.broadcast %2034, %2035 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %2033, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %2032, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %2038 = shape.shape_of %2037 : tensor -> tensor<2xindex> + %2039 = shape.cstr_broadcastable %2038, %2035 : tensor<2xindex>, tensor<2xindex> + %2040 = shape.assuming %2039 -> (tensor) { + %19688 = shape.broadcast %2038, %2035 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %2037, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %2032, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %2041 = stablehlo.dot %2040, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %2042 = stablehlo.reshape %1997 : (tensor<3x2xf32>) -> tensor<3x2x1xf32> + %dim_689 = tensor.dim %2012, %c0 : tensor + %2043 = arith.index_cast %dim_689 : index to i64 + %from_elements_690 = tensor.from_elements %2043, %c1_i64 : tensor<2xi64> + %2044 = stablehlo.dynamic_reshape %2012, %from_elements_690 : (tensor, tensor<2xi64>) -> tensor + %dim_691 = tensor.dim %2009, %c0 : tensor + %2045 = arith.index_cast %dim_691 : index to i64 + %from_elements_692 = tensor.from_elements %2045, %c1_i64 : tensor<2xi64> + %2046 = stablehlo.dynamic_reshape %2009, %from_elements_692 : (tensor, tensor<2xi64>) -> tensor + %2047 = stablehlo.concatenate %2044, %2046, dim = 1 : (tensor, tensor) -> tensor + %2048 = "stablehlo.gather"(%2042, %2047) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %2049 = shape.shape_of %2041 : tensor -> tensor<2xindex> + %2050 = shape.shape_of %2048 : tensor -> tensor<2xindex> + %2051 = shape.cstr_broadcastable %2049, %2050 : tensor<2xindex>, tensor<2xindex> + %2052 = shape.assuming %2051 -> (tensor) { + %19688 = shape.broadcast %2049, %2050 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %2041, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %2048, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = 
stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %2053 = shape.shape_of %2052 : tensor -> tensor<2xindex> + %2054 = stablehlo.dynamic_broadcast_in_dim %2052, %2053, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %2055 = stablehlo.dynamic_broadcast_in_dim %213, %2053, dims = [] : (tensor, tensor<2xindex>) -> tensor + %2056 = stablehlo.multiply %2054, %2055 : tensor + %dim_693 = tensor.dim %2015, %c0 : tensor + %2057 = arith.index_cast %dim_693 : index to i64 + %dim_694 = tensor.dim %2052, %c0 : tensor + %2058 = arith.index_cast %dim_694 : index to i64 + %2059 = arith.maxsi %2057, %2058 : i64 + %2060 = arith.index_cast %2059 : i64 to index + %from_elements_695 = tensor.from_elements %2060, %c4096 : tensor<2xindex> + %2061 = stablehlo.dynamic_broadcast_in_dim %2015, %from_elements_695, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_696 = tensor.dim %2061, %c0 : tensor + %2062 = arith.index_cast %dim_696 : index to i64 + %from_elements_697 = tensor.from_elements %2062, %c4096_i64 : tensor<2xi64> + %2063 = stablehlo.real_dynamic_slice %2056, %c_22, %from_elements_697, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_698 = tensor.from_elements %2062, %c4096_i64, %c1_i64 : tensor<3xi64> + %2064 = stablehlo.dynamic_reshape %2061, %from_elements_698 : (tensor, tensor<3xi64>) -> tensor + %2065 = stablehlo.dynamic_iota %from_elements_698, dim = 1 : (tensor<3xi64>) -> tensor + %2066 = stablehlo.concatenate %2064, %2065, dim = 2 : (tensor, tensor) -> tensor + %2067 = "stablehlo.scatter"(%cst_2, %2066, %2063) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %2068 = stablehlo.slice %2002 [1:2, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %2069 = stablehlo.reshape %2068 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %2070 = stablehlo.custom_call @byteir.non_zero(%2069) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_699 = tensor.dim %2070, %c0 : tensor + %2071 = arith.index_cast %dim_699 : index to i64 + %from_elements_700 = tensor.from_elements %2071, %c1_i64 : tensor<2xi64> + %2072 = stablehlo.real_dynamic_slice %2070, %c_22, %from_elements_700, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_701 = tensor.dim %2072, %c0 : tensor + %2073 = arith.index_cast %dim_701 : index to i64 + %from_elements_702 = tensor.from_elements %2073 : tensor<1xi64> + %2074 = stablehlo.dynamic_reshape %2072, %from_elements_702 : (tensor, tensor<1xi64>) -> tensor + %from_elements_703 = tensor.from_elements %2071, %c2_i64 : tensor<2xi64> + %2075 = stablehlo.real_dynamic_slice %2070, %c_24, %from_elements_703, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_704 = tensor.dim %2075, %c0 : tensor + %2076 = arith.index_cast %dim_704 : index to i64 + %from_elements_705 = tensor.from_elements %2076 : tensor<1xi64> + %2077 = stablehlo.dynamic_reshape %2075, %from_elements_705 : (tensor, tensor<1xi64>) -> tensor + %dim_706 = tensor.dim %2077, %c0 : tensor + %2078 = arith.index_cast %dim_706 : index to i64 + %from_elements_707 = tensor.from_elements %2078, %c1_i64 : tensor<2xi64> + %2079 = stablehlo.dynamic_reshape %2077, %from_elements_707 : (tensor, tensor<2xi64>) -> tensor + %dim_708 = tensor.dim %2079, %c0 : 
tensor + %2080 = arith.index_cast %dim_708 : index to i64 + %from_elements_709 = tensor.from_elements %c1_i64, %2080, %c4096_i64 : tensor<3xi64> + %2081 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_709, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_710 = tensor.dim %2081, %c1 : tensor<1x?x4096xi64> + %2082 = arith.index_cast %dim_710 : index to i64 + %from_elements_711 = tensor.from_elements %c1_i64, %2082, %c4096_i64, %c1_i64 : tensor<4xi64> + %2083 = stablehlo.dynamic_reshape %2081, %from_elements_711 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %2084 = stablehlo.dynamic_broadcast_in_dim %2079, %from_elements_709, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_712 = tensor.dim %2084, %c1 : tensor<1x?x4096xi64> + %2085 = arith.index_cast %dim_712 : index to i64 + %from_elements_713 = tensor.from_elements %c1_i64, %2085, %c4096_i64, %c1_i64 : tensor<4xi64> + %2086 = stablehlo.dynamic_reshape %2084, %from_elements_713 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %2087 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_709, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_714 = tensor.dim %2087, %c1 : tensor<1x?x4096xi64> + %2088 = arith.index_cast %dim_714 : index to i64 + %from_elements_715 = tensor.from_elements %c1_i64, %2088, %c4096_i64, %c1_i64 : tensor<4xi64> + %2089 = stablehlo.dynamic_reshape %2087, %from_elements_715 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %2090 = stablehlo.concatenate %2083, %2086, %2089, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %2091 = "stablehlo.gather"(%2013, %2090) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %2092 = shape.shape_of %2091 : tensor<1x?x4096xf32> -> tensor<3xindex> + %2093 = shape.num_elements %2092 : tensor<3xindex> -> index + %2094 = stablehlo.compute_reshape_shape %2093, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %2095 = stablehlo.dynamic_reshape %2091, %2094 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %2096 = stablehlo.dot %2095, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %2097 = stablehlo.logistic %2096 : tensor + %2098 = shape.shape_of %2097 : tensor -> tensor<2xindex> + %2099 = shape.shape_of %2096 : tensor -> tensor<2xindex> + %2100 = shape.cstr_broadcastable %2098, %2099 : tensor<2xindex>, tensor<2xindex> + %2101 = shape.assuming %2100 -> (tensor) { + %19688 = shape.broadcast %2098, %2099 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %2097, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %2096, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %2102 = shape.shape_of %2101 : tensor -> tensor<2xindex> + %2103 = shape.cstr_broadcastable %2102, %2099 : tensor<2xindex>, tensor<2xindex> + %2104 = shape.assuming %2103 -> (tensor) { + %19688 = shape.broadcast %2102, %2099 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %2101, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %2096, %19688, dims = [0, 1] : (tensor, 
tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %2105 = stablehlo.dot %2104, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_716 = tensor.dim %2077, %c0 : tensor + %2106 = arith.index_cast %dim_716 : index to i64 + %from_elements_717 = tensor.from_elements %2106, %c1_i64 : tensor<2xi64> + %2107 = stablehlo.dynamic_reshape %2077, %from_elements_717 : (tensor, tensor<2xi64>) -> tensor + %dim_718 = tensor.dim %2074, %c0 : tensor + %2108 = arith.index_cast %dim_718 : index to i64 + %from_elements_719 = tensor.from_elements %2108, %c1_i64 : tensor<2xi64> + %2109 = stablehlo.dynamic_reshape %2074, %from_elements_719 : (tensor, tensor<2xi64>) -> tensor + %2110 = stablehlo.concatenate %2107, %2109, dim = 1 : (tensor, tensor) -> tensor + %2111 = "stablehlo.gather"(%2042, %2110) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %2112 = shape.shape_of %2105 : tensor -> tensor<2xindex> + %2113 = shape.shape_of %2111 : tensor -> tensor<2xindex> + %2114 = shape.cstr_broadcastable %2112, %2113 : tensor<2xindex>, tensor<2xindex> + %2115 = shape.assuming %2114 -> (tensor) { + %19688 = shape.broadcast %2112, %2113 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %2105, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %2111, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %2116 = shape.shape_of %2115 : tensor -> tensor<2xindex> + %2117 = stablehlo.dynamic_broadcast_in_dim %2115, %2116, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %2118 = stablehlo.dynamic_broadcast_in_dim %213, %2116, dims = [] : (tensor, tensor<2xindex>) -> tensor + %2119 = stablehlo.multiply %2117, %2118 : tensor + %dim_720 = tensor.dim %2079, %c0 : tensor + %2120 = arith.index_cast %dim_720 : index to i64 + %dim_721 = tensor.dim %2115, %c0 : tensor + %2121 = arith.index_cast %dim_721 : index to i64 + %2122 = arith.maxsi %2120, %2121 : i64 + %2123 = arith.index_cast %2122 : i64 to index + %from_elements_722 = tensor.from_elements %2123, %c4096 : tensor<2xindex> + %2124 = stablehlo.dynamic_broadcast_in_dim %2079, %from_elements_722, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_723 = tensor.dim %2124, %c0 : tensor + %2125 = arith.index_cast %dim_723 : index to i64 + %from_elements_724 = tensor.from_elements %2125, %c4096_i64 : tensor<2xi64> + %2126 = stablehlo.real_dynamic_slice %2119, %c_22, %from_elements_724, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_725 = tensor.from_elements %2125, %c4096_i64, %c1_i64 : tensor<3xi64> + %2127 = stablehlo.dynamic_reshape %2124, %from_elements_725 : (tensor, tensor<3xi64>) -> tensor + %2128 = stablehlo.dynamic_iota %from_elements_725, dim = 1 : (tensor<3xi64>) -> tensor + %2129 = stablehlo.concatenate %2127, %2128, dim = 2 : (tensor, tensor) -> tensor + %2130 = "stablehlo.scatter"(%2067, %2129, %2126) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %2131 = stablehlo.slice %2002 [2:3, 0:2, 0:3] : 
(tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %2132 = stablehlo.reshape %2131 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %2133 = stablehlo.custom_call @byteir.non_zero(%2132) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_726 = tensor.dim %2133, %c0 : tensor + %2134 = arith.index_cast %dim_726 : index to i64 + %from_elements_727 = tensor.from_elements %2134, %c1_i64 : tensor<2xi64> + %2135 = stablehlo.real_dynamic_slice %2133, %c_22, %from_elements_727, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_728 = tensor.dim %2135, %c0 : tensor + %2136 = arith.index_cast %dim_728 : index to i64 + %from_elements_729 = tensor.from_elements %2136 : tensor<1xi64> + %2137 = stablehlo.dynamic_reshape %2135, %from_elements_729 : (tensor, tensor<1xi64>) -> tensor + %from_elements_730 = tensor.from_elements %2134, %c2_i64 : tensor<2xi64> + %2138 = stablehlo.real_dynamic_slice %2133, %c_24, %from_elements_730, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_731 = tensor.dim %2138, %c0 : tensor + %2139 = arith.index_cast %dim_731 : index to i64 + %from_elements_732 = tensor.from_elements %2139 : tensor<1xi64> + %2140 = stablehlo.dynamic_reshape %2138, %from_elements_732 : (tensor, tensor<1xi64>) -> tensor + %dim_733 = tensor.dim %2140, %c0 : tensor + %2141 = arith.index_cast %dim_733 : index to i64 + %from_elements_734 = tensor.from_elements %2141, %c1_i64 : tensor<2xi64> + %2142 = stablehlo.dynamic_reshape %2140, %from_elements_734 : (tensor, tensor<2xi64>) -> tensor + %dim_735 = tensor.dim %2142, %c0 : tensor + %2143 = arith.index_cast %dim_735 : index to i64 + %from_elements_736 = tensor.from_elements %c1_i64, %2143, %c4096_i64 : tensor<3xi64> + %2144 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_736, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_737 = tensor.dim %2144, %c1 : tensor<1x?x4096xi64> + %2145 = arith.index_cast %dim_737 : index to i64 + %from_elements_738 = tensor.from_elements %c1_i64, %2145, %c4096_i64, %c1_i64 : tensor<4xi64> + %2146 = stablehlo.dynamic_reshape %2144, %from_elements_738 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %2147 = stablehlo.dynamic_broadcast_in_dim %2142, %from_elements_736, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_739 = tensor.dim %2147, %c1 : tensor<1x?x4096xi64> + %2148 = arith.index_cast %dim_739 : index to i64 + %from_elements_740 = tensor.from_elements %c1_i64, %2148, %c4096_i64, %c1_i64 : tensor<4xi64> + %2149 = stablehlo.dynamic_reshape %2147, %from_elements_740 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %2150 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_736, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_741 = tensor.dim %2150, %c1 : tensor<1x?x4096xi64> + %2151 = arith.index_cast %dim_741 : index to i64 + %from_elements_742 = tensor.from_elements %c1_i64, %2151, %c4096_i64, %c1_i64 : tensor<4xi64> + %2152 = stablehlo.dynamic_reshape %2150, %from_elements_742 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %2153 = stablehlo.concatenate %2146, %2149, %2152, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %2154 = "stablehlo.gather"(%2013, %2153) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %2155 = shape.shape_of %2154 : 
tensor<1x?x4096xf32> -> tensor<3xindex> + %2156 = shape.num_elements %2155 : tensor<3xindex> -> index + %2157 = stablehlo.compute_reshape_shape %2156, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %2158 = stablehlo.dynamic_reshape %2154, %2157 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %2159 = stablehlo.dot %2158, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %2160 = stablehlo.logistic %2159 : tensor + %2161 = shape.shape_of %2160 : tensor -> tensor<2xindex> + %2162 = shape.shape_of %2159 : tensor -> tensor<2xindex> + %2163 = shape.cstr_broadcastable %2161, %2162 : tensor<2xindex>, tensor<2xindex> + %2164 = shape.assuming %2163 -> (tensor) { + %19688 = shape.broadcast %2161, %2162 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %2160, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %2159, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %2165 = shape.shape_of %2164 : tensor -> tensor<2xindex> + %2166 = shape.cstr_broadcastable %2165, %2162 : tensor<2xindex>, tensor<2xindex> + %2167 = shape.assuming %2166 -> (tensor) { + %19688 = shape.broadcast %2165, %2162 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %2164, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %2159, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %2168 = stablehlo.dot %2167, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_743 = tensor.dim %2140, %c0 : tensor + %2169 = arith.index_cast %dim_743 : index to i64 + %from_elements_744 = tensor.from_elements %2169, %c1_i64 : tensor<2xi64> + %2170 = stablehlo.dynamic_reshape %2140, %from_elements_744 : (tensor, tensor<2xi64>) -> tensor + %dim_745 = tensor.dim %2137, %c0 : tensor + %2171 = arith.index_cast %dim_745 : index to i64 + %from_elements_746 = tensor.from_elements %2171, %c1_i64 : tensor<2xi64> + %2172 = stablehlo.dynamic_reshape %2137, %from_elements_746 : (tensor, tensor<2xi64>) -> tensor + %2173 = stablehlo.concatenate %2170, %2172, dim = 1 : (tensor, tensor) -> tensor + %2174 = "stablehlo.gather"(%2042, %2173) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %2175 = shape.shape_of %2168 : tensor -> tensor<2xindex> + %2176 = shape.shape_of %2174 : tensor -> tensor<2xindex> + %2177 = shape.cstr_broadcastable %2175, %2176 : tensor<2xindex>, tensor<2xindex> + %2178 = shape.assuming %2177 -> (tensor) { + %19688 = shape.broadcast %2175, %2176 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %2168, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %2174, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %2179 = shape.shape_of %2178 : tensor -> tensor<2xindex> + %2180 = stablehlo.dynamic_broadcast_in_dim %2178, %2179, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %2181 = stablehlo.dynamic_broadcast_in_dim %213, %2179, dims = [] : (tensor, tensor<2xindex>) -> tensor + %2182 = stablehlo.multiply %2180, %2181 : tensor 
+ %dim_747 = tensor.dim %2142, %c0 : tensor + %2183 = arith.index_cast %dim_747 : index to i64 + %dim_748 = tensor.dim %2178, %c0 : tensor + %2184 = arith.index_cast %dim_748 : index to i64 + %2185 = arith.maxsi %2183, %2184 : i64 + %2186 = arith.index_cast %2185 : i64 to index + %from_elements_749 = tensor.from_elements %2186, %c4096 : tensor<2xindex> + %2187 = stablehlo.dynamic_broadcast_in_dim %2142, %from_elements_749, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_750 = tensor.dim %2187, %c0 : tensor + %2188 = arith.index_cast %dim_750 : index to i64 + %from_elements_751 = tensor.from_elements %2188, %c4096_i64 : tensor<2xi64> + %2189 = stablehlo.real_dynamic_slice %2182, %c_22, %from_elements_751, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_752 = tensor.from_elements %2188, %c4096_i64, %c1_i64 : tensor<3xi64> + %2190 = stablehlo.dynamic_reshape %2187, %from_elements_752 : (tensor, tensor<3xi64>) -> tensor + %2191 = stablehlo.dynamic_iota %from_elements_752, dim = 1 : (tensor<3xi64>) -> tensor + %2192 = stablehlo.concatenate %2190, %2191, dim = 2 : (tensor, tensor) -> tensor + %2193 = "stablehlo.scatter"(%2130, %2192, %2189) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %2194 = stablehlo.slice %2002 [3:4, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %2195 = stablehlo.reshape %2194 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %2196 = stablehlo.custom_call @byteir.non_zero(%2195) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_753 = tensor.dim %2196, %c0 : tensor + %2197 = arith.index_cast %dim_753 : index to i64 + %from_elements_754 = tensor.from_elements %2197, %c1_i64 : tensor<2xi64> + %2198 = stablehlo.real_dynamic_slice %2196, %c_22, %from_elements_754, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_755 = tensor.dim %2198, %c0 : tensor + %2199 = arith.index_cast %dim_755 : index to i64 + %from_elements_756 = tensor.from_elements %2199 : tensor<1xi64> + %2200 = stablehlo.dynamic_reshape %2198, %from_elements_756 : (tensor, tensor<1xi64>) -> tensor + %from_elements_757 = tensor.from_elements %2197, %c2_i64 : tensor<2xi64> + %2201 = stablehlo.real_dynamic_slice %2196, %c_24, %from_elements_757, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_758 = tensor.dim %2201, %c0 : tensor + %2202 = arith.index_cast %dim_758 : index to i64 + %from_elements_759 = tensor.from_elements %2202 : tensor<1xi64> + %2203 = stablehlo.dynamic_reshape %2201, %from_elements_759 : (tensor, tensor<1xi64>) -> tensor + %dim_760 = tensor.dim %2203, %c0 : tensor + %2204 = arith.index_cast %dim_760 : index to i64 + %from_elements_761 = tensor.from_elements %2204, %c1_i64 : tensor<2xi64> + %2205 = stablehlo.dynamic_reshape %2203, %from_elements_761 : (tensor, tensor<2xi64>) -> tensor + %dim_762 = tensor.dim %2205, %c0 : tensor + %2206 = arith.index_cast %dim_762 : index to i64 + %from_elements_763 = tensor.from_elements %c1_i64, %2206, %c4096_i64 : tensor<3xi64> + %2207 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_763, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_764 = tensor.dim %2207, %c1 : tensor<1x?x4096xi64> + %2208 = arith.index_cast %dim_764 : index to i64 + 
%from_elements_765 = tensor.from_elements %c1_i64, %2208, %c4096_i64, %c1_i64 : tensor<4xi64> + %2209 = stablehlo.dynamic_reshape %2207, %from_elements_765 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %2210 = stablehlo.dynamic_broadcast_in_dim %2205, %from_elements_763, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_766 = tensor.dim %2210, %c1 : tensor<1x?x4096xi64> + %2211 = arith.index_cast %dim_766 : index to i64 + %from_elements_767 = tensor.from_elements %c1_i64, %2211, %c4096_i64, %c1_i64 : tensor<4xi64> + %2212 = stablehlo.dynamic_reshape %2210, %from_elements_767 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %2213 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_763, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_768 = tensor.dim %2213, %c1 : tensor<1x?x4096xi64> + %2214 = arith.index_cast %dim_768 : index to i64 + %from_elements_769 = tensor.from_elements %c1_i64, %2214, %c4096_i64, %c1_i64 : tensor<4xi64> + %2215 = stablehlo.dynamic_reshape %2213, %from_elements_769 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %2216 = stablehlo.concatenate %2209, %2212, %2215, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %2217 = "stablehlo.gather"(%2013, %2216) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %2218 = shape.shape_of %2217 : tensor<1x?x4096xf32> -> tensor<3xindex> + %2219 = shape.num_elements %2218 : tensor<3xindex> -> index + %2220 = stablehlo.compute_reshape_shape %2219, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %2221 = stablehlo.dynamic_reshape %2217, %2220 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %2222 = stablehlo.dot %2221, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %2223 = stablehlo.logistic %2222 : tensor + %2224 = shape.shape_of %2223 : tensor -> tensor<2xindex> + %2225 = shape.shape_of %2222 : tensor -> tensor<2xindex> + %2226 = shape.cstr_broadcastable %2224, %2225 : tensor<2xindex>, tensor<2xindex> + %2227 = shape.assuming %2226 -> (tensor) { + %19688 = shape.broadcast %2224, %2225 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %2223, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %2222, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %2228 = shape.shape_of %2227 : tensor -> tensor<2xindex> + %2229 = shape.cstr_broadcastable %2228, %2225 : tensor<2xindex>, tensor<2xindex> + %2230 = shape.assuming %2229 -> (tensor) { + %19688 = shape.broadcast %2228, %2225 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %2227, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %2222, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %2231 = stablehlo.dot %2230, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_770 = tensor.dim %2203, %c0 : tensor + %2232 = arith.index_cast %dim_770 : index to i64 + %from_elements_771 = tensor.from_elements %2232, %c1_i64 : tensor<2xi64> + %2233 = stablehlo.dynamic_reshape %2203, 
%from_elements_771 : (tensor, tensor<2xi64>) -> tensor + %dim_772 = tensor.dim %2200, %c0 : tensor + %2234 = arith.index_cast %dim_772 : index to i64 + %from_elements_773 = tensor.from_elements %2234, %c1_i64 : tensor<2xi64> + %2235 = stablehlo.dynamic_reshape %2200, %from_elements_773 : (tensor, tensor<2xi64>) -> tensor + %2236 = stablehlo.concatenate %2233, %2235, dim = 1 : (tensor, tensor) -> tensor + %2237 = "stablehlo.gather"(%2042, %2236) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %2238 = shape.shape_of %2231 : tensor -> tensor<2xindex> + %2239 = shape.shape_of %2237 : tensor -> tensor<2xindex> + %2240 = shape.cstr_broadcastable %2238, %2239 : tensor<2xindex>, tensor<2xindex> + %2241 = shape.assuming %2240 -> (tensor) { + %19688 = shape.broadcast %2238, %2239 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %2231, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %2237, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %2242 = shape.shape_of %2241 : tensor -> tensor<2xindex> + %2243 = stablehlo.dynamic_broadcast_in_dim %2241, %2242, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %2244 = stablehlo.dynamic_broadcast_in_dim %213, %2242, dims = [] : (tensor, tensor<2xindex>) -> tensor + %2245 = stablehlo.multiply %2243, %2244 : tensor + %dim_774 = tensor.dim %2205, %c0 : tensor + %2246 = arith.index_cast %dim_774 : index to i64 + %dim_775 = tensor.dim %2241, %c0 : tensor + %2247 = arith.index_cast %dim_775 : index to i64 + %2248 = arith.maxsi %2246, %2247 : i64 + %2249 = arith.index_cast %2248 : i64 to index + %from_elements_776 = tensor.from_elements %2249, %c4096 : tensor<2xindex> + %2250 = stablehlo.dynamic_broadcast_in_dim %2205, %from_elements_776, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_777 = tensor.dim %2250, %c0 : tensor + %2251 = arith.index_cast %dim_777 : index to i64 + %from_elements_778 = tensor.from_elements %2251, %c4096_i64 : tensor<2xi64> + %2252 = stablehlo.real_dynamic_slice %2245, %c_22, %from_elements_778, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_779 = tensor.from_elements %2251, %c4096_i64, %c1_i64 : tensor<3xi64> + %2253 = stablehlo.dynamic_reshape %2250, %from_elements_779 : (tensor, tensor<3xi64>) -> tensor + %2254 = stablehlo.dynamic_iota %from_elements_779, dim = 1 : (tensor<3xi64>) -> tensor + %2255 = stablehlo.concatenate %2253, %2254, dim = 2 : (tensor, tensor) -> tensor + %2256 = "stablehlo.scatter"(%2193, %2255, %2252) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %2257 = stablehlo.slice %2002 [4:5, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %2258 = stablehlo.reshape %2257 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %2259 = stablehlo.custom_call @byteir.non_zero(%2258) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_780 = tensor.dim %2259, %c0 : tensor + %2260 = arith.index_cast %dim_780 : index to i64 + %from_elements_781 = tensor.from_elements %2260, %c1_i64 : tensor<2xi64> + %2261 = 
stablehlo.real_dynamic_slice %2259, %c_22, %from_elements_781, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_782 = tensor.dim %2261, %c0 : tensor + %2262 = arith.index_cast %dim_782 : index to i64 + %from_elements_783 = tensor.from_elements %2262 : tensor<1xi64> + %2263 = stablehlo.dynamic_reshape %2261, %from_elements_783 : (tensor, tensor<1xi64>) -> tensor + %from_elements_784 = tensor.from_elements %2260, %c2_i64 : tensor<2xi64> + %2264 = stablehlo.real_dynamic_slice %2259, %c_24, %from_elements_784, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_785 = tensor.dim %2264, %c0 : tensor + %2265 = arith.index_cast %dim_785 : index to i64 + %from_elements_786 = tensor.from_elements %2265 : tensor<1xi64> + %2266 = stablehlo.dynamic_reshape %2264, %from_elements_786 : (tensor, tensor<1xi64>) -> tensor + %dim_787 = tensor.dim %2266, %c0 : tensor + %2267 = arith.index_cast %dim_787 : index to i64 + %from_elements_788 = tensor.from_elements %2267, %c1_i64 : tensor<2xi64> + %2268 = stablehlo.dynamic_reshape %2266, %from_elements_788 : (tensor, tensor<2xi64>) -> tensor + %dim_789 = tensor.dim %2268, %c0 : tensor + %2269 = arith.index_cast %dim_789 : index to i64 + %from_elements_790 = tensor.from_elements %c1_i64, %2269, %c4096_i64 : tensor<3xi64> + %2270 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_790, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_791 = tensor.dim %2270, %c1 : tensor<1x?x4096xi64> + %2271 = arith.index_cast %dim_791 : index to i64 + %from_elements_792 = tensor.from_elements %c1_i64, %2271, %c4096_i64, %c1_i64 : tensor<4xi64> + %2272 = stablehlo.dynamic_reshape %2270, %from_elements_792 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %2273 = stablehlo.dynamic_broadcast_in_dim %2268, %from_elements_790, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_793 = tensor.dim %2273, %c1 : tensor<1x?x4096xi64> + %2274 = arith.index_cast %dim_793 : index to i64 + %from_elements_794 = tensor.from_elements %c1_i64, %2274, %c4096_i64, %c1_i64 : tensor<4xi64> + %2275 = stablehlo.dynamic_reshape %2273, %from_elements_794 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %2276 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_790, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_795 = tensor.dim %2276, %c1 : tensor<1x?x4096xi64> + %2277 = arith.index_cast %dim_795 : index to i64 + %from_elements_796 = tensor.from_elements %c1_i64, %2277, %c4096_i64, %c1_i64 : tensor<4xi64> + %2278 = stablehlo.dynamic_reshape %2276, %from_elements_796 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %2279 = stablehlo.concatenate %2272, %2275, %2278, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %2280 = "stablehlo.gather"(%2013, %2279) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %2281 = shape.shape_of %2280 : tensor<1x?x4096xf32> -> tensor<3xindex> + %2282 = shape.num_elements %2281 : tensor<3xindex> -> index + %2283 = stablehlo.compute_reshape_shape %2282, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %2284 = stablehlo.dynamic_reshape %2280, %2283 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %2285 = stablehlo.dot %2284, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %2286 = 
stablehlo.logistic %2285 : tensor + %2287 = shape.shape_of %2286 : tensor -> tensor<2xindex> + %2288 = shape.shape_of %2285 : tensor -> tensor<2xindex> + %2289 = shape.cstr_broadcastable %2287, %2288 : tensor<2xindex>, tensor<2xindex> + %2290 = shape.assuming %2289 -> (tensor) { + %19688 = shape.broadcast %2287, %2288 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %2286, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %2285, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %2291 = shape.shape_of %2290 : tensor -> tensor<2xindex> + %2292 = shape.cstr_broadcastable %2291, %2288 : tensor<2xindex>, tensor<2xindex> + %2293 = shape.assuming %2292 -> (tensor) { + %19688 = shape.broadcast %2291, %2288 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %2290, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %2285, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %2294 = stablehlo.dot %2293, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_797 = tensor.dim %2266, %c0 : tensor + %2295 = arith.index_cast %dim_797 : index to i64 + %from_elements_798 = tensor.from_elements %2295, %c1_i64 : tensor<2xi64> + %2296 = stablehlo.dynamic_reshape %2266, %from_elements_798 : (tensor, tensor<2xi64>) -> tensor + %dim_799 = tensor.dim %2263, %c0 : tensor + %2297 = arith.index_cast %dim_799 : index to i64 + %from_elements_800 = tensor.from_elements %2297, %c1_i64 : tensor<2xi64> + %2298 = stablehlo.dynamic_reshape %2263, %from_elements_800 : (tensor, tensor<2xi64>) -> tensor + %2299 = stablehlo.concatenate %2296, %2298, dim = 1 : (tensor, tensor) -> tensor + %2300 = "stablehlo.gather"(%2042, %2299) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %2301 = shape.shape_of %2294 : tensor -> tensor<2xindex> + %2302 = shape.shape_of %2300 : tensor -> tensor<2xindex> + %2303 = shape.cstr_broadcastable %2301, %2302 : tensor<2xindex>, tensor<2xindex> + %2304 = shape.assuming %2303 -> (tensor) { + %19688 = shape.broadcast %2301, %2302 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %2294, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %2300, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %2305 = shape.shape_of %2304 : tensor -> tensor<2xindex> + %2306 = stablehlo.dynamic_broadcast_in_dim %2304, %2305, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %2307 = stablehlo.dynamic_broadcast_in_dim %213, %2305, dims = [] : (tensor, tensor<2xindex>) -> tensor + %2308 = stablehlo.multiply %2306, %2307 : tensor + %dim_801 = tensor.dim %2268, %c0 : tensor + %2309 = arith.index_cast %dim_801 : index to i64 + %dim_802 = tensor.dim %2304, %c0 : tensor + %2310 = arith.index_cast %dim_802 : index to i64 + %2311 = arith.maxsi %2309, %2310 : i64 + %2312 = arith.index_cast %2311 : i64 to index + %from_elements_803 = tensor.from_elements %2312, %c4096 : tensor<2xindex> + %2313 = 
stablehlo.dynamic_broadcast_in_dim %2268, %from_elements_803, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_804 = tensor.dim %2313, %c0 : tensor + %2314 = arith.index_cast %dim_804 : index to i64 + %from_elements_805 = tensor.from_elements %2314, %c4096_i64 : tensor<2xi64> + %2315 = stablehlo.real_dynamic_slice %2308, %c_22, %from_elements_805, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_806 = tensor.from_elements %2314, %c4096_i64, %c1_i64 : tensor<3xi64> + %2316 = stablehlo.dynamic_reshape %2313, %from_elements_806 : (tensor, tensor<3xi64>) -> tensor + %2317 = stablehlo.dynamic_iota %from_elements_806, dim = 1 : (tensor<3xi64>) -> tensor + %2318 = stablehlo.concatenate %2316, %2317, dim = 2 : (tensor, tensor) -> tensor + %2319 = "stablehlo.scatter"(%2256, %2318, %2315) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %2320 = stablehlo.slice %2002 [5:6, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %2321 = stablehlo.reshape %2320 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %2322 = stablehlo.custom_call @byteir.non_zero(%2321) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_807 = tensor.dim %2322, %c0 : tensor + %2323 = arith.index_cast %dim_807 : index to i64 + %from_elements_808 = tensor.from_elements %2323, %c1_i64 : tensor<2xi64> + %2324 = stablehlo.real_dynamic_slice %2322, %c_22, %from_elements_808, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_809 = tensor.dim %2324, %c0 : tensor + %2325 = arith.index_cast %dim_809 : index to i64 + %from_elements_810 = tensor.from_elements %2325 : tensor<1xi64> + %2326 = stablehlo.dynamic_reshape %2324, %from_elements_810 : (tensor, tensor<1xi64>) -> tensor + %from_elements_811 = tensor.from_elements %2323, %c2_i64 : tensor<2xi64> + %2327 = stablehlo.real_dynamic_slice %2322, %c_24, %from_elements_811, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_812 = tensor.dim %2327, %c0 : tensor + %2328 = arith.index_cast %dim_812 : index to i64 + %from_elements_813 = tensor.from_elements %2328 : tensor<1xi64> + %2329 = stablehlo.dynamic_reshape %2327, %from_elements_813 : (tensor, tensor<1xi64>) -> tensor + %dim_814 = tensor.dim %2329, %c0 : tensor + %2330 = arith.index_cast %dim_814 : index to i64 + %from_elements_815 = tensor.from_elements %2330, %c1_i64 : tensor<2xi64> + %2331 = stablehlo.dynamic_reshape %2329, %from_elements_815 : (tensor, tensor<2xi64>) -> tensor + %dim_816 = tensor.dim %2331, %c0 : tensor + %2332 = arith.index_cast %dim_816 : index to i64 + %from_elements_817 = tensor.from_elements %c1_i64, %2332, %c4096_i64 : tensor<3xi64> + %2333 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_817, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_818 = tensor.dim %2333, %c1 : tensor<1x?x4096xi64> + %2334 = arith.index_cast %dim_818 : index to i64 + %from_elements_819 = tensor.from_elements %c1_i64, %2334, %c4096_i64, %c1_i64 : tensor<4xi64> + %2335 = stablehlo.dynamic_reshape %2333, %from_elements_819 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %2336 = stablehlo.dynamic_broadcast_in_dim %2331, %from_elements_817, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_820 = 
tensor.dim %2336, %c1 : tensor<1x?x4096xi64> + %2337 = arith.index_cast %dim_820 : index to i64 + %from_elements_821 = tensor.from_elements %c1_i64, %2337, %c4096_i64, %c1_i64 : tensor<4xi64> + %2338 = stablehlo.dynamic_reshape %2336, %from_elements_821 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %2339 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_817, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_822 = tensor.dim %2339, %c1 : tensor<1x?x4096xi64> + %2340 = arith.index_cast %dim_822 : index to i64 + %from_elements_823 = tensor.from_elements %c1_i64, %2340, %c4096_i64, %c1_i64 : tensor<4xi64> + %2341 = stablehlo.dynamic_reshape %2339, %from_elements_823 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %2342 = stablehlo.concatenate %2335, %2338, %2341, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %2343 = "stablehlo.gather"(%2013, %2342) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %2344 = shape.shape_of %2343 : tensor<1x?x4096xf32> -> tensor<3xindex> + %2345 = shape.num_elements %2344 : tensor<3xindex> -> index + %2346 = stablehlo.compute_reshape_shape %2345, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %2347 = stablehlo.dynamic_reshape %2343, %2346 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %2348 = stablehlo.dot %2347, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %2349 = stablehlo.logistic %2348 : tensor + %2350 = shape.shape_of %2349 : tensor -> tensor<2xindex> + %2351 = shape.shape_of %2348 : tensor -> tensor<2xindex> + %2352 = shape.cstr_broadcastable %2350, %2351 : tensor<2xindex>, tensor<2xindex> + %2353 = shape.assuming %2352 -> (tensor) { + %19688 = shape.broadcast %2350, %2351 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %2349, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %2348, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %2354 = shape.shape_of %2353 : tensor -> tensor<2xindex> + %2355 = shape.cstr_broadcastable %2354, %2351 : tensor<2xindex>, tensor<2xindex> + %2356 = shape.assuming %2355 -> (tensor) { + %19688 = shape.broadcast %2354, %2351 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %2353, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %2348, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %2357 = stablehlo.dot %2356, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_824 = tensor.dim %2329, %c0 : tensor + %2358 = arith.index_cast %dim_824 : index to i64 + %from_elements_825 = tensor.from_elements %2358, %c1_i64 : tensor<2xi64> + %2359 = stablehlo.dynamic_reshape %2329, %from_elements_825 : (tensor, tensor<2xi64>) -> tensor + %dim_826 = tensor.dim %2326, %c0 : tensor + %2360 = arith.index_cast %dim_826 : index to i64 + %from_elements_827 = tensor.from_elements %2360, %c1_i64 : tensor<2xi64> + %2361 = stablehlo.dynamic_reshape %2326, %from_elements_827 : (tensor, tensor<2xi64>) -> tensor + %2362 = stablehlo.concatenate %2359, %2361, dim 
= 1 : (tensor, tensor) -> tensor + %2363 = "stablehlo.gather"(%2042, %2362) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %2364 = shape.shape_of %2357 : tensor -> tensor<2xindex> + %2365 = shape.shape_of %2363 : tensor -> tensor<2xindex> + %2366 = shape.cstr_broadcastable %2364, %2365 : tensor<2xindex>, tensor<2xindex> + %2367 = shape.assuming %2366 -> (tensor) { + %19688 = shape.broadcast %2364, %2365 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %2357, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %2363, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %2368 = shape.shape_of %2367 : tensor -> tensor<2xindex> + %2369 = stablehlo.dynamic_broadcast_in_dim %2367, %2368, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %2370 = stablehlo.dynamic_broadcast_in_dim %213, %2368, dims = [] : (tensor, tensor<2xindex>) -> tensor + %2371 = stablehlo.multiply %2369, %2370 : tensor + %dim_828 = tensor.dim %2331, %c0 : tensor + %2372 = arith.index_cast %dim_828 : index to i64 + %dim_829 = tensor.dim %2367, %c0 : tensor + %2373 = arith.index_cast %dim_829 : index to i64 + %2374 = arith.maxsi %2372, %2373 : i64 + %2375 = arith.index_cast %2374 : i64 to index + %from_elements_830 = tensor.from_elements %2375, %c4096 : tensor<2xindex> + %2376 = stablehlo.dynamic_broadcast_in_dim %2331, %from_elements_830, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_831 = tensor.dim %2376, %c0 : tensor + %2377 = arith.index_cast %dim_831 : index to i64 + %from_elements_832 = tensor.from_elements %2377, %c4096_i64 : tensor<2xi64> + %2378 = stablehlo.real_dynamic_slice %2371, %c_22, %from_elements_832, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_833 = tensor.from_elements %2377, %c4096_i64, %c1_i64 : tensor<3xi64> + %2379 = stablehlo.dynamic_reshape %2376, %from_elements_833 : (tensor, tensor<3xi64>) -> tensor + %2380 = stablehlo.dynamic_iota %from_elements_833, dim = 1 : (tensor<3xi64>) -> tensor + %2381 = stablehlo.concatenate %2379, %2380, dim = 2 : (tensor, tensor) -> tensor + %2382 = "stablehlo.scatter"(%2319, %2381, %2378) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %2383 = stablehlo.slice %2002 [6:7, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %2384 = stablehlo.reshape %2383 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %2385 = stablehlo.custom_call @byteir.non_zero(%2384) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_834 = tensor.dim %2385, %c0 : tensor + %2386 = arith.index_cast %dim_834 : index to i64 + %from_elements_835 = tensor.from_elements %2386, %c1_i64 : tensor<2xi64> + %2387 = stablehlo.real_dynamic_slice %2385, %c_22, %from_elements_835, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_836 = tensor.dim %2387, %c0 : tensor + %2388 = arith.index_cast %dim_836 : index to i64 + %from_elements_837 = tensor.from_elements %2388 : tensor<1xi64> + %2389 = stablehlo.dynamic_reshape %2387, %from_elements_837 : (tensor, tensor<1xi64>) -> tensor + 
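// Each expert block in this sparse-MoE layer repeats the same sequence: slice expert k's
+ // row of the top-2 routing mask, locate the routed (token, slot) pairs with the
+ // byteir.non_zero custom call, gather the corresponding hidden states, apply the expert
+ // MLP (a 4096 -> 14336 projection with SiLU gating, then a 14336 -> 4096 projection),
+ // scale by the routing weight gathered from the 3x2x1 normalized top-2 weights, and
+ // scatter-add the result into the shared 3x4096 accumulator.
+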
%from_elements_838 = tensor.from_elements %2386, %c2_i64 : tensor<2xi64> + %2390 = stablehlo.real_dynamic_slice %2385, %c_24, %from_elements_838, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_839 = tensor.dim %2390, %c0 : tensor + %2391 = arith.index_cast %dim_839 : index to i64 + %from_elements_840 = tensor.from_elements %2391 : tensor<1xi64> + %2392 = stablehlo.dynamic_reshape %2390, %from_elements_840 : (tensor, tensor<1xi64>) -> tensor + %dim_841 = tensor.dim %2392, %c0 : tensor + %2393 = arith.index_cast %dim_841 : index to i64 + %from_elements_842 = tensor.from_elements %2393, %c1_i64 : tensor<2xi64> + %2394 = stablehlo.dynamic_reshape %2392, %from_elements_842 : (tensor, tensor<2xi64>) -> tensor + %dim_843 = tensor.dim %2394, %c0 : tensor + %2395 = arith.index_cast %dim_843 : index to i64 + %from_elements_844 = tensor.from_elements %c1_i64, %2395, %c4096_i64 : tensor<3xi64> + %2396 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_844, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_845 = tensor.dim %2396, %c1 : tensor<1x?x4096xi64> + %2397 = arith.index_cast %dim_845 : index to i64 + %from_elements_846 = tensor.from_elements %c1_i64, %2397, %c4096_i64, %c1_i64 : tensor<4xi64> + %2398 = stablehlo.dynamic_reshape %2396, %from_elements_846 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %2399 = stablehlo.dynamic_broadcast_in_dim %2394, %from_elements_844, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_847 = tensor.dim %2399, %c1 : tensor<1x?x4096xi64> + %2400 = arith.index_cast %dim_847 : index to i64 + %from_elements_848 = tensor.from_elements %c1_i64, %2400, %c4096_i64, %c1_i64 : tensor<4xi64> + %2401 = stablehlo.dynamic_reshape %2399, %from_elements_848 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %2402 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_844, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_849 = tensor.dim %2402, %c1 : tensor<1x?x4096xi64> + %2403 = arith.index_cast %dim_849 : index to i64 + %from_elements_850 = tensor.from_elements %c1_i64, %2403, %c4096_i64, %c1_i64 : tensor<4xi64> + %2404 = stablehlo.dynamic_reshape %2402, %from_elements_850 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %2405 = stablehlo.concatenate %2398, %2401, %2404, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %2406 = "stablehlo.gather"(%2013, %2405) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %2407 = shape.shape_of %2406 : tensor<1x?x4096xf32> -> tensor<3xindex> + %2408 = shape.num_elements %2407 : tensor<3xindex> -> index + %2409 = stablehlo.compute_reshape_shape %2408, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %2410 = stablehlo.dynamic_reshape %2406, %2409 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %2411 = stablehlo.dot %2410, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %2412 = stablehlo.logistic %2411 : tensor + %2413 = shape.shape_of %2412 : tensor -> tensor<2xindex> + %2414 = shape.shape_of %2411 : tensor -> tensor<2xindex> + %2415 = shape.cstr_broadcastable %2413, %2414 : tensor<2xindex>, tensor<2xindex> + %2416 = shape.assuming %2415 -> (tensor) { + %19688 = shape.broadcast %2413, %2414 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = 
stablehlo.dynamic_broadcast_in_dim %2412, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %2411, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %2417 = shape.shape_of %2416 : tensor -> tensor<2xindex> + %2418 = shape.cstr_broadcastable %2417, %2414 : tensor<2xindex>, tensor<2xindex> + %2419 = shape.assuming %2418 -> (tensor) { + %19688 = shape.broadcast %2417, %2414 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %2416, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %2411, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %2420 = stablehlo.dot %2419, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_851 = tensor.dim %2392, %c0 : tensor + %2421 = arith.index_cast %dim_851 : index to i64 + %from_elements_852 = tensor.from_elements %2421, %c1_i64 : tensor<2xi64> + %2422 = stablehlo.dynamic_reshape %2392, %from_elements_852 : (tensor, tensor<2xi64>) -> tensor + %dim_853 = tensor.dim %2389, %c0 : tensor + %2423 = arith.index_cast %dim_853 : index to i64 + %from_elements_854 = tensor.from_elements %2423, %c1_i64 : tensor<2xi64> + %2424 = stablehlo.dynamic_reshape %2389, %from_elements_854 : (tensor, tensor<2xi64>) -> tensor + %2425 = stablehlo.concatenate %2422, %2424, dim = 1 : (tensor, tensor) -> tensor + %2426 = "stablehlo.gather"(%2042, %2425) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %2427 = shape.shape_of %2420 : tensor -> tensor<2xindex> + %2428 = shape.shape_of %2426 : tensor -> tensor<2xindex> + %2429 = shape.cstr_broadcastable %2427, %2428 : tensor<2xindex>, tensor<2xindex> + %2430 = shape.assuming %2429 -> (tensor) { + %19688 = shape.broadcast %2427, %2428 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %2420, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %2426, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %2431 = shape.shape_of %2430 : tensor -> tensor<2xindex> + %2432 = stablehlo.dynamic_broadcast_in_dim %2430, %2431, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %2433 = stablehlo.dynamic_broadcast_in_dim %213, %2431, dims = [] : (tensor, tensor<2xindex>) -> tensor + %2434 = stablehlo.multiply %2432, %2433 : tensor + %dim_855 = tensor.dim %2394, %c0 : tensor + %2435 = arith.index_cast %dim_855 : index to i64 + %dim_856 = tensor.dim %2430, %c0 : tensor + %2436 = arith.index_cast %dim_856 : index to i64 + %2437 = arith.maxsi %2435, %2436 : i64 + %2438 = arith.index_cast %2437 : i64 to index + %from_elements_857 = tensor.from_elements %2438, %c4096 : tensor<2xindex> + %2439 = stablehlo.dynamic_broadcast_in_dim %2394, %from_elements_857, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_858 = tensor.dim %2439, %c0 : tensor + %2440 = arith.index_cast %dim_858 : index to i64 + %from_elements_859 = tensor.from_elements %2440, %c4096_i64 : tensor<2xi64> + %2441 = stablehlo.real_dynamic_slice %2434, %c_22, %from_elements_859, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, 
tensor<2xi64>) -> tensor + %from_elements_860 = tensor.from_elements %2440, %c4096_i64, %c1_i64 : tensor<3xi64> + %2442 = stablehlo.dynamic_reshape %2439, %from_elements_860 : (tensor, tensor<3xi64>) -> tensor + %2443 = stablehlo.dynamic_iota %from_elements_860, dim = 1 : (tensor<3xi64>) -> tensor + %2444 = stablehlo.concatenate %2442, %2443, dim = 2 : (tensor, tensor) -> tensor + %2445 = "stablehlo.scatter"(%2382, %2444, %2441) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %2446 = stablehlo.slice %2002 [7:8, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %2447 = stablehlo.reshape %2446 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %2448 = stablehlo.custom_call @byteir.non_zero(%2447) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_861 = tensor.dim %2448, %c0 : tensor + %2449 = arith.index_cast %dim_861 : index to i64 + %from_elements_862 = tensor.from_elements %2449, %c1_i64 : tensor<2xi64> + %2450 = stablehlo.real_dynamic_slice %2448, %c_22, %from_elements_862, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_863 = tensor.dim %2450, %c0 : tensor + %2451 = arith.index_cast %dim_863 : index to i64 + %from_elements_864 = tensor.from_elements %2451 : tensor<1xi64> + %2452 = stablehlo.dynamic_reshape %2450, %from_elements_864 : (tensor, tensor<1xi64>) -> tensor + %from_elements_865 = tensor.from_elements %2449, %c2_i64 : tensor<2xi64> + %2453 = stablehlo.real_dynamic_slice %2448, %c_24, %from_elements_865, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_866 = tensor.dim %2453, %c0 : tensor + %2454 = arith.index_cast %dim_866 : index to i64 + %from_elements_867 = tensor.from_elements %2454 : tensor<1xi64> + %2455 = stablehlo.dynamic_reshape %2453, %from_elements_867 : (tensor, tensor<1xi64>) -> tensor + %dim_868 = tensor.dim %2455, %c0 : tensor + %2456 = arith.index_cast %dim_868 : index to i64 + %from_elements_869 = tensor.from_elements %2456, %c1_i64 : tensor<2xi64> + %2457 = stablehlo.dynamic_reshape %2455, %from_elements_869 : (tensor, tensor<2xi64>) -> tensor + %dim_870 = tensor.dim %2457, %c0 : tensor + %2458 = arith.index_cast %dim_870 : index to i64 + %from_elements_871 = tensor.from_elements %c1_i64, %2458, %c4096_i64 : tensor<3xi64> + %2459 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_871, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_872 = tensor.dim %2459, %c1 : tensor<1x?x4096xi64> + %2460 = arith.index_cast %dim_872 : index to i64 + %from_elements_873 = tensor.from_elements %c1_i64, %2460, %c4096_i64, %c1_i64 : tensor<4xi64> + %2461 = stablehlo.dynamic_reshape %2459, %from_elements_873 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %2462 = stablehlo.dynamic_broadcast_in_dim %2457, %from_elements_871, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_874 = tensor.dim %2462, %c1 : tensor<1x?x4096xi64> + %2463 = arith.index_cast %dim_874 : index to i64 + %from_elements_875 = tensor.from_elements %c1_i64, %2463, %c4096_i64, %c1_i64 : tensor<4xi64> + %2464 = stablehlo.dynamic_reshape %2462, %from_elements_875 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %2465 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_871, dims = [2] : 
(tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_876 = tensor.dim %2465, %c1 : tensor<1x?x4096xi64> + %2466 = arith.index_cast %dim_876 : index to i64 + %from_elements_877 = tensor.from_elements %c1_i64, %2466, %c4096_i64, %c1_i64 : tensor<4xi64> + %2467 = stablehlo.dynamic_reshape %2465, %from_elements_877 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %2468 = stablehlo.concatenate %2461, %2464, %2467, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %2469 = "stablehlo.gather"(%2013, %2468) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %2470 = shape.shape_of %2469 : tensor<1x?x4096xf32> -> tensor<3xindex> + %2471 = shape.num_elements %2470 : tensor<3xindex> -> index + %2472 = stablehlo.compute_reshape_shape %2471, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %2473 = stablehlo.dynamic_reshape %2469, %2472 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %2474 = stablehlo.dot %2473, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %2475 = stablehlo.logistic %2474 : tensor + %2476 = shape.shape_of %2475 : tensor -> tensor<2xindex> + %2477 = shape.shape_of %2474 : tensor -> tensor<2xindex> + %2478 = shape.cstr_broadcastable %2476, %2477 : tensor<2xindex>, tensor<2xindex> + %2479 = shape.assuming %2478 -> (tensor) { + %19688 = shape.broadcast %2476, %2477 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %2475, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %2474, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %2480 = shape.shape_of %2479 : tensor -> tensor<2xindex> + %2481 = shape.cstr_broadcastable %2480, %2477 : tensor<2xindex>, tensor<2xindex> + %2482 = shape.assuming %2481 -> (tensor) { + %19688 = shape.broadcast %2480, %2477 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %2479, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %2474, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %2483 = stablehlo.dot %2482, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_878 = tensor.dim %2455, %c0 : tensor + %2484 = arith.index_cast %dim_878 : index to i64 + %from_elements_879 = tensor.from_elements %2484, %c1_i64 : tensor<2xi64> + %2485 = stablehlo.dynamic_reshape %2455, %from_elements_879 : (tensor, tensor<2xi64>) -> tensor + %dim_880 = tensor.dim %2452, %c0 : tensor + %2486 = arith.index_cast %dim_880 : index to i64 + %from_elements_881 = tensor.from_elements %2486, %c1_i64 : tensor<2xi64> + %2487 = stablehlo.dynamic_reshape %2452, %from_elements_881 : (tensor, tensor<2xi64>) -> tensor + %2488 = stablehlo.concatenate %2485, %2487, dim = 1 : (tensor, tensor) -> tensor + %2489 = "stablehlo.gather"(%2042, %2488) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %2490 = shape.shape_of %2483 : tensor -> tensor<2xindex> + %2491 = shape.shape_of %2489 : tensor -> tensor<2xindex> + %2492 = shape.cstr_broadcastable %2490, %2491 : tensor<2xindex>, 
tensor<2xindex> + %2493 = shape.assuming %2492 -> (tensor) { + %19688 = shape.broadcast %2490, %2491 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %2483, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %2489, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %2494 = shape.shape_of %2493 : tensor -> tensor<2xindex> + %2495 = stablehlo.dynamic_broadcast_in_dim %2493, %2494, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %2496 = stablehlo.dynamic_broadcast_in_dim %213, %2494, dims = [] : (tensor, tensor<2xindex>) -> tensor + %2497 = stablehlo.multiply %2495, %2496 : tensor + %dim_882 = tensor.dim %2457, %c0 : tensor + %2498 = arith.index_cast %dim_882 : index to i64 + %dim_883 = tensor.dim %2493, %c0 : tensor + %2499 = arith.index_cast %dim_883 : index to i64 + %2500 = arith.maxsi %2498, %2499 : i64 + %2501 = arith.index_cast %2500 : i64 to index + %from_elements_884 = tensor.from_elements %2501, %c4096 : tensor<2xindex> + %2502 = stablehlo.dynamic_broadcast_in_dim %2457, %from_elements_884, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_885 = tensor.dim %2502, %c0 : tensor + %2503 = arith.index_cast %dim_885 : index to i64 + %from_elements_886 = tensor.from_elements %2503, %c4096_i64 : tensor<2xi64> + %2504 = stablehlo.real_dynamic_slice %2497, %c_22, %from_elements_886, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_887 = tensor.from_elements %2503, %c4096_i64, %c1_i64 : tensor<3xi64> + %2505 = stablehlo.dynamic_reshape %2502, %from_elements_887 : (tensor, tensor<3xi64>) -> tensor + %2506 = stablehlo.dynamic_iota %from_elements_887, dim = 1 : (tensor<3xi64>) -> tensor + %2507 = stablehlo.concatenate %2505, %2506, dim = 2 : (tensor, tensor) -> tensor + %2508 = "stablehlo.scatter"(%2445, %2507, %2504) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %2509 = stablehlo.reshape %2508 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %2510 = stablehlo.add %1975, %2509 : tensor<3x1x4096xf32> + %2511 = stablehlo.broadcast_in_dim %2510, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %2512 = stablehlo.power %2511, %15 : tensor<3x1x4096xf32> + %2513 = stablehlo.reduce(%2512 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %2514 = stablehlo.reshape %2513 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %2515 = stablehlo.broadcast_in_dim %2514, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %2516 = stablehlo.divide %2515, %21 : tensor<3x1x1xf32> + %2517 = stablehlo.broadcast_in_dim %2516, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %2518 = stablehlo.add %2517, %25 : tensor<3x1x1xf32> + %2519 = stablehlo.rsqrt %2518 : tensor<3x1x1xf32> + %2520 = stablehlo.broadcast_in_dim %2519, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %2521 = stablehlo.multiply %2511, %2520 : tensor<3x1x4096xf32> + %2522 = stablehlo.broadcast_in_dim %2521, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %2523 = stablehlo.multiply %2522, %31 : 
tensor<3x1x4096xf32> + %2524 = stablehlo.reshape %2523 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %2525 = stablehlo.dot %2524, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %2526 = stablehlo.reshape %2525 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %2527 = stablehlo.dot %2524, %37 : (tensor<3x4096xf32>, tensor<4096x1024xf32>) -> tensor<3x1024xf32> + %2528 = stablehlo.reshape %2527 : (tensor<3x1024xf32>) -> tensor<3x1x1024xf32> + %2529 = stablehlo.reshape %2526 : (tensor<3x1x4096xf32>) -> tensor<3x1x32x128xf32> + %2530 = stablehlo.transpose %2529, dims = [0, 2, 1, 3] : (tensor<3x1x32x128xf32>) -> tensor<3x32x1x128xf32> + %2531 = stablehlo.reshape %2528 : (tensor<3x1x1024xf32>) -> tensor<3x1x8x128xf32> + %2532 = stablehlo.transpose %2531, dims = [0, 2, 1, 3] : (tensor<3x1x8x128xf32>) -> tensor<3x8x1x128xf32> + %2533 = stablehlo.slice %arg8 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %2534 = stablehlo.slice %arg9 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %2535 = "stablehlo.gather"(%2533, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %2536 = stablehlo.reshape %2535 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %2537 = "stablehlo.gather"(%2534, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %2538 = stablehlo.reshape %2537 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %2539 = stablehlo.broadcast_in_dim %2530, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %2540 = stablehlo.broadcast_in_dim %2536, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %2541 = stablehlo.multiply %2539, %2540 : tensor<3x32x1x128xf32> + %2542 = stablehlo.slice %2530 [0:3, 0:32, 0:1, 0:64] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %2543 = stablehlo.slice %2530 [0:3, 0:32, 0:1, 64:128] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %2544 = stablehlo.negate %2543 : tensor<3x32x1x64xf32> + %2545 = stablehlo.concatenate %2544, %2542, dim = 3 : (tensor<3x32x1x64xf32>, tensor<3x32x1x64xf32>) -> tensor<3x32x1x128xf32> + %2546 = stablehlo.broadcast_in_dim %2545, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %2547 = stablehlo.broadcast_in_dim %2538, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %2548 = stablehlo.multiply %2546, %2547 : tensor<3x32x1x128xf32> + %2549 = stablehlo.add %2541, %2548 : tensor<3x32x1x128xf32> + %2550 = stablehlo.broadcast_in_dim %2532, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %2551 = stablehlo.broadcast_in_dim %2536, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %2552 = stablehlo.multiply %2550, %2551 : tensor<3x8x1x128xf32> + %2553 = stablehlo.slice %2532 [0:3, 0:8, 0:1, 0:64] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %2554 = stablehlo.slice %2532 [0:3, 0:8, 0:1, 64:128] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %2555 = stablehlo.negate %2554 : tensor<3x8x1x64xf32> + %2556 = stablehlo.concatenate %2555, %2553, dim = 3 : (tensor<3x8x1x64xf32>, tensor<3x8x1x64xf32>) -> tensor<3x8x1x128xf32> + %2557 = stablehlo.broadcast_in_dim %2556, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %2558 = stablehlo.broadcast_in_dim %2538, dims = [0, 1, 2, 3] : 
(tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %2559 = stablehlo.multiply %2557, %2558 : tensor<3x8x1x128xf32> + %2560 = stablehlo.add %2552, %2559 : tensor<3x8x1x128xf32> + %2561 = stablehlo.concatenate %arg73, %2560, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %2562 = stablehlo.concatenate %arg74, %2532, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %2563 = stablehlo.reshape %2561 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %2564 = stablehlo.broadcast_in_dim %2563, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %2565 = stablehlo.reshape %2564 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %2566 = stablehlo.reshape %2562 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %2567 = stablehlo.broadcast_in_dim %2566, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %2568 = stablehlo.reshape %2567 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %2569 = stablehlo.transpose %2565, dims = [0, 1, 3, 2] : (tensor<3x32x8x128xf32>) -> tensor<3x32x128x8xf32> + %2570 = stablehlo.reshape %2549 : (tensor<3x32x1x128xf32>) -> tensor<96x1x128xf32> + %2571 = stablehlo.reshape %2569 : (tensor<3x32x128x8xf32>) -> tensor<96x128x8xf32> + %2572 = stablehlo.broadcast_in_dim %2571, dims = [0, 1, 2] : (tensor<96x128x8xf32>) -> tensor<96x128x8xf32> + %2573 = stablehlo.dot_general %2570, %2572, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x128xf32>, tensor<96x128x8xf32>) -> tensor<96x1x8xf32> + %2574 = stablehlo.reshape %2573 : (tensor<96x1x8xf32>) -> tensor<3x32x1x8xf32> + %2575 = stablehlo.broadcast_in_dim %2574, dims = [0, 1, 2, 3] : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %2576 = stablehlo.divide %2575, %89 : tensor<3x32x1x8xf32> + %2577 = stablehlo.custom_call @byteir.softmax(%2576) {byteir_attrs = {axis = 3 : i64}} : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %2578 = stablehlo.reshape %2577 : (tensor<3x32x1x8xf32>) -> tensor<96x1x8xf32> + %2579 = stablehlo.reshape %2568 : (tensor<3x32x8x128xf32>) -> tensor<96x8x128xf32> + %2580 = stablehlo.broadcast_in_dim %2579, dims = [0, 1, 2] : (tensor<96x8x128xf32>) -> tensor<96x8x128xf32> + %2581 = stablehlo.dot_general %2578, %2580, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x8xf32>, tensor<96x8x128xf32>) -> tensor<96x1x128xf32> + %2582 = stablehlo.reshape %2581 : (tensor<96x1x128xf32>) -> tensor<3x32x1x128xf32> + %2583 = stablehlo.transpose %2582, dims = [0, 2, 1, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x1x32x128xf32> + %2584 = stablehlo.reshape %2583 : (tensor<3x1x32x128xf32>) -> tensor<3x1x4096xf32> + %2585 = stablehlo.reshape %2584 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %2586 = stablehlo.dot %2585, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %2587 = stablehlo.reshape %2586 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %2588 = stablehlo.add %2510, %2587 : tensor<3x1x4096xf32> + %2589 = stablehlo.broadcast_in_dim %2588, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %2590 = stablehlo.power %2589, %15 : tensor<3x1x4096xf32> + %2591 = stablehlo.reduce(%2590 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %2592 = stablehlo.reshape %2591 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %2593 = stablehlo.broadcast_in_dim %2592, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %2594 = 
stablehlo.divide %2593, %21 : tensor<3x1x1xf32> + %2595 = stablehlo.broadcast_in_dim %2594, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %2596 = stablehlo.add %2595, %25 : tensor<3x1x1xf32> + %2597 = stablehlo.rsqrt %2596 : tensor<3x1x1xf32> + %2598 = stablehlo.broadcast_in_dim %2597, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %2599 = stablehlo.multiply %2589, %2598 : tensor<3x1x4096xf32> + %2600 = stablehlo.broadcast_in_dim %2599, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %2601 = stablehlo.multiply %2600, %31 : tensor<3x1x4096xf32> + %2602 = stablehlo.reshape %2601 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %2603 = stablehlo.dot %2602, %117 : (tensor<3x4096xf32>, tensor<4096x8xf32>) -> tensor<3x8xf32> + %2604 = stablehlo.custom_call @byteir.softmax(%2603) {byteir_attrs = {axis = 1 : i64}} : (tensor<3x8xf32>) -> tensor<3x8xf32> + %2605:2 = stablehlo.custom_call @byteir.top_k(%2604) {byteir_attrs = {axis = [1], k = 2 : i64, sorted = true}} : (tensor<3x8xf32>) -> (tensor<3x2xf32>, tensor<3x2xi64>) + %2606 = stablehlo.reduce(%2605#0 init: %cst_20) applies stablehlo.add across dimensions = [1] : (tensor<3x2xf32>, tensor) -> tensor<3xf32> + %2607 = stablehlo.reshape %2606 : (tensor<3xf32>) -> tensor<3x1xf32> + %2608 = stablehlo.broadcast_in_dim %2605#0, dims = [0, 1] : (tensor<3x2xf32>) -> tensor<3x2xf32> + %2609 = stablehlo.broadcast_in_dim %2607, dims = [0, 1] : (tensor<3x1xf32>) -> tensor<3x2xf32> + %2610 = stablehlo.divide %2608, %2609 : tensor<3x2xf32> + %2611 = stablehlo.reshape %2605#1 : (tensor<3x2xi64>) -> tensor<3x2x1xi64> + %2612 = stablehlo.broadcast_in_dim %2611, dims = [0, 1, 2] : (tensor<3x2x1xi64>) -> tensor<3x2x8xi64> + %2613 = stablehlo.compare EQ, %2612, %137, SIGNED : (tensor<3x2x8xi64>, tensor<3x2x8xi64>) -> tensor<3x2x8xi1> + %2614 = stablehlo.convert %2613 : (tensor<3x2x8xi1>) -> tensor<3x2x8xi64> + %2615 = stablehlo.transpose %2614, dims = [2, 1, 0] : (tensor<3x2x8xi64>) -> tensor<8x2x3xi64> + %2616 = stablehlo.slice %2615 [0:1, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %2617 = stablehlo.reshape %2616 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %2618 = stablehlo.custom_call @byteir.non_zero(%2617) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_888 = tensor.dim %2618, %c0 : tensor + %2619 = arith.index_cast %dim_888 : index to i64 + %from_elements_889 = tensor.from_elements %2619, %c1_i64 : tensor<2xi64> + %2620 = stablehlo.real_dynamic_slice %2618, %c_22, %from_elements_889, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_890 = tensor.dim %2620, %c0 : tensor + %2621 = arith.index_cast %dim_890 : index to i64 + %from_elements_891 = tensor.from_elements %2621 : tensor<1xi64> + %2622 = stablehlo.dynamic_reshape %2620, %from_elements_891 : (tensor, tensor<1xi64>) -> tensor + %from_elements_892 = tensor.from_elements %2619, %c2_i64 : tensor<2xi64> + %2623 = stablehlo.real_dynamic_slice %2618, %c_24, %from_elements_892, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_893 = tensor.dim %2623, %c0 : tensor + %2624 = arith.index_cast %dim_893 : index to i64 + %from_elements_894 = tensor.from_elements %2624 : tensor<1xi64> + %2625 = stablehlo.dynamic_reshape %2623, %from_elements_894 : (tensor, tensor<1xi64>) -> tensor + %2626 = stablehlo.reshape %2602 : (tensor<3x4096xf32>) -> tensor<1x3x4096xf32> + %dim_895 = tensor.dim %2625, %c0 : tensor + %2627 = arith.index_cast %dim_895 : index to i64 + %from_elements_896 = 
tensor.from_elements %2627, %c1_i64 : tensor<2xi64> + %2628 = stablehlo.dynamic_reshape %2625, %from_elements_896 : (tensor, tensor<2xi64>) -> tensor + %dim_897 = tensor.dim %2628, %c0 : tensor + %2629 = arith.index_cast %dim_897 : index to i64 + %from_elements_898 = tensor.from_elements %c1_i64, %2629, %c4096_i64 : tensor<3xi64> + %2630 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_898, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_899 = tensor.dim %2630, %c1 : tensor<1x?x4096xi64> + %2631 = arith.index_cast %dim_899 : index to i64 + %from_elements_900 = tensor.from_elements %c1_i64, %2631, %c4096_i64, %c1_i64 : tensor<4xi64> + %2632 = stablehlo.dynamic_reshape %2630, %from_elements_900 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %2633 = stablehlo.dynamic_broadcast_in_dim %2628, %from_elements_898, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_901 = tensor.dim %2633, %c1 : tensor<1x?x4096xi64> + %2634 = arith.index_cast %dim_901 : index to i64 + %from_elements_902 = tensor.from_elements %c1_i64, %2634, %c4096_i64, %c1_i64 : tensor<4xi64> + %2635 = stablehlo.dynamic_reshape %2633, %from_elements_902 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %2636 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_898, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_903 = tensor.dim %2636, %c1 : tensor<1x?x4096xi64> + %2637 = arith.index_cast %dim_903 : index to i64 + %from_elements_904 = tensor.from_elements %c1_i64, %2637, %c4096_i64, %c1_i64 : tensor<4xi64> + %2638 = stablehlo.dynamic_reshape %2636, %from_elements_904 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %2639 = stablehlo.concatenate %2632, %2635, %2638, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %2640 = "stablehlo.gather"(%2626, %2639) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %2641 = shape.shape_of %2640 : tensor<1x?x4096xf32> -> tensor<3xindex> + %2642 = shape.num_elements %2641 : tensor<3xindex> -> index + %2643 = stablehlo.compute_reshape_shape %2642, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %2644 = stablehlo.dynamic_reshape %2640, %2643 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %2645 = stablehlo.dot %2644, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %2646 = stablehlo.logistic %2645 : tensor + %2647 = shape.shape_of %2646 : tensor -> tensor<2xindex> + %2648 = shape.shape_of %2645 : tensor -> tensor<2xindex> + %2649 = shape.cstr_broadcastable %2647, %2648 : tensor<2xindex>, tensor<2xindex> + %2650 = shape.assuming %2649 -> (tensor) { + %19688 = shape.broadcast %2647, %2648 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %2646, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %2645, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %2651 = shape.shape_of %2650 : tensor -> tensor<2xindex> + %2652 = shape.cstr_broadcastable %2651, %2648 : tensor<2xindex>, tensor<2xindex> + %2653 = shape.assuming %2652 -> (tensor) { + %19688 = shape.broadcast %2651, %2648 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = 
stablehlo.dynamic_broadcast_in_dim %2650, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %2645, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %2654 = stablehlo.dot %2653, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %2655 = stablehlo.reshape %2610 : (tensor<3x2xf32>) -> tensor<3x2x1xf32> + %dim_905 = tensor.dim %2625, %c0 : tensor + %2656 = arith.index_cast %dim_905 : index to i64 + %from_elements_906 = tensor.from_elements %2656, %c1_i64 : tensor<2xi64> + %2657 = stablehlo.dynamic_reshape %2625, %from_elements_906 : (tensor, tensor<2xi64>) -> tensor + %dim_907 = tensor.dim %2622, %c0 : tensor + %2658 = arith.index_cast %dim_907 : index to i64 + %from_elements_908 = tensor.from_elements %2658, %c1_i64 : tensor<2xi64> + %2659 = stablehlo.dynamic_reshape %2622, %from_elements_908 : (tensor, tensor<2xi64>) -> tensor + %2660 = stablehlo.concatenate %2657, %2659, dim = 1 : (tensor, tensor) -> tensor + %2661 = "stablehlo.gather"(%2655, %2660) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %2662 = shape.shape_of %2654 : tensor -> tensor<2xindex> + %2663 = shape.shape_of %2661 : tensor -> tensor<2xindex> + %2664 = shape.cstr_broadcastable %2662, %2663 : tensor<2xindex>, tensor<2xindex> + %2665 = shape.assuming %2664 -> (tensor) { + %19688 = shape.broadcast %2662, %2663 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %2654, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %2661, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %2666 = shape.shape_of %2665 : tensor -> tensor<2xindex> + %2667 = stablehlo.dynamic_broadcast_in_dim %2665, %2666, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %2668 = stablehlo.dynamic_broadcast_in_dim %213, %2666, dims = [] : (tensor, tensor<2xindex>) -> tensor + %2669 = stablehlo.multiply %2667, %2668 : tensor + %dim_909 = tensor.dim %2628, %c0 : tensor + %2670 = arith.index_cast %dim_909 : index to i64 + %dim_910 = tensor.dim %2665, %c0 : tensor + %2671 = arith.index_cast %dim_910 : index to i64 + %2672 = arith.maxsi %2670, %2671 : i64 + %2673 = arith.index_cast %2672 : i64 to index + %from_elements_911 = tensor.from_elements %2673, %c4096 : tensor<2xindex> + %2674 = stablehlo.dynamic_broadcast_in_dim %2628, %from_elements_911, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_912 = tensor.dim %2674, %c0 : tensor + %2675 = arith.index_cast %dim_912 : index to i64 + %from_elements_913 = tensor.from_elements %2675, %c4096_i64 : tensor<2xi64> + %2676 = stablehlo.real_dynamic_slice %2669, %c_22, %from_elements_913, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_914 = tensor.from_elements %2675, %c4096_i64, %c1_i64 : tensor<3xi64> + %2677 = stablehlo.dynamic_reshape %2674, %from_elements_914 : (tensor, tensor<3xi64>) -> tensor + %2678 = stablehlo.dynamic_iota %from_elements_914, dim = 1 : (tensor<3xi64>) -> tensor + %2679 = stablehlo.concatenate %2677, %2678, dim = 2 : (tensor, tensor) -> tensor + %2680 = "stablehlo.scatter"(%cst_2, %2679, %2676) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = 
false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %2681 = stablehlo.slice %2615 [1:2, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %2682 = stablehlo.reshape %2681 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %2683 = stablehlo.custom_call @byteir.non_zero(%2682) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_915 = tensor.dim %2683, %c0 : tensor + %2684 = arith.index_cast %dim_915 : index to i64 + %from_elements_916 = tensor.from_elements %2684, %c1_i64 : tensor<2xi64> + %2685 = stablehlo.real_dynamic_slice %2683, %c_22, %from_elements_916, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_917 = tensor.dim %2685, %c0 : tensor + %2686 = arith.index_cast %dim_917 : index to i64 + %from_elements_918 = tensor.from_elements %2686 : tensor<1xi64> + %2687 = stablehlo.dynamic_reshape %2685, %from_elements_918 : (tensor, tensor<1xi64>) -> tensor + %from_elements_919 = tensor.from_elements %2684, %c2_i64 : tensor<2xi64> + %2688 = stablehlo.real_dynamic_slice %2683, %c_24, %from_elements_919, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_920 = tensor.dim %2688, %c0 : tensor + %2689 = arith.index_cast %dim_920 : index to i64 + %from_elements_921 = tensor.from_elements %2689 : tensor<1xi64> + %2690 = stablehlo.dynamic_reshape %2688, %from_elements_921 : (tensor, tensor<1xi64>) -> tensor + %dim_922 = tensor.dim %2690, %c0 : tensor + %2691 = arith.index_cast %dim_922 : index to i64 + %from_elements_923 = tensor.from_elements %2691, %c1_i64 : tensor<2xi64> + %2692 = stablehlo.dynamic_reshape %2690, %from_elements_923 : (tensor, tensor<2xi64>) -> tensor + %dim_924 = tensor.dim %2692, %c0 : tensor + %2693 = arith.index_cast %dim_924 : index to i64 + %from_elements_925 = tensor.from_elements %c1_i64, %2693, %c4096_i64 : tensor<3xi64> + %2694 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_925, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_926 = tensor.dim %2694, %c1 : tensor<1x?x4096xi64> + %2695 = arith.index_cast %dim_926 : index to i64 + %from_elements_927 = tensor.from_elements %c1_i64, %2695, %c4096_i64, %c1_i64 : tensor<4xi64> + %2696 = stablehlo.dynamic_reshape %2694, %from_elements_927 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %2697 = stablehlo.dynamic_broadcast_in_dim %2692, %from_elements_925, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_928 = tensor.dim %2697, %c1 : tensor<1x?x4096xi64> + %2698 = arith.index_cast %dim_928 : index to i64 + %from_elements_929 = tensor.from_elements %c1_i64, %2698, %c4096_i64, %c1_i64 : tensor<4xi64> + %2699 = stablehlo.dynamic_reshape %2697, %from_elements_929 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %2700 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_925, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_930 = tensor.dim %2700, %c1 : tensor<1x?x4096xi64> + %2701 = arith.index_cast %dim_930 : index to i64 + %from_elements_931 = tensor.from_elements %c1_i64, %2701, %c4096_i64, %c1_i64 : tensor<4xi64> + %2702 = stablehlo.dynamic_reshape %2700, %from_elements_931 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %2703 = stablehlo.concatenate %2696, %2699, %2702, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> 
tensor<1x?x4096x3xi64> + %2704 = "stablehlo.gather"(%2626, %2703) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %2705 = shape.shape_of %2704 : tensor<1x?x4096xf32> -> tensor<3xindex> + %2706 = shape.num_elements %2705 : tensor<3xindex> -> index + %2707 = stablehlo.compute_reshape_shape %2706, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %2708 = stablehlo.dynamic_reshape %2704, %2707 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %2709 = stablehlo.dot %2708, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %2710 = stablehlo.logistic %2709 : tensor + %2711 = shape.shape_of %2710 : tensor -> tensor<2xindex> + %2712 = shape.shape_of %2709 : tensor -> tensor<2xindex> + %2713 = shape.cstr_broadcastable %2711, %2712 : tensor<2xindex>, tensor<2xindex> + %2714 = shape.assuming %2713 -> (tensor) { + %19688 = shape.broadcast %2711, %2712 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %2710, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %2709, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %2715 = shape.shape_of %2714 : tensor -> tensor<2xindex> + %2716 = shape.cstr_broadcastable %2715, %2712 : tensor<2xindex>, tensor<2xindex> + %2717 = shape.assuming %2716 -> (tensor) { + %19688 = shape.broadcast %2715, %2712 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %2714, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %2709, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %2718 = stablehlo.dot %2717, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_932 = tensor.dim %2690, %c0 : tensor + %2719 = arith.index_cast %dim_932 : index to i64 + %from_elements_933 = tensor.from_elements %2719, %c1_i64 : tensor<2xi64> + %2720 = stablehlo.dynamic_reshape %2690, %from_elements_933 : (tensor, tensor<2xi64>) -> tensor + %dim_934 = tensor.dim %2687, %c0 : tensor + %2721 = arith.index_cast %dim_934 : index to i64 + %from_elements_935 = tensor.from_elements %2721, %c1_i64 : tensor<2xi64> + %2722 = stablehlo.dynamic_reshape %2687, %from_elements_935 : (tensor, tensor<2xi64>) -> tensor + %2723 = stablehlo.concatenate %2720, %2722, dim = 1 : (tensor, tensor) -> tensor + %2724 = "stablehlo.gather"(%2655, %2723) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %2725 = shape.shape_of %2718 : tensor -> tensor<2xindex> + %2726 = shape.shape_of %2724 : tensor -> tensor<2xindex> + %2727 = shape.cstr_broadcastable %2725, %2726 : tensor<2xindex>, tensor<2xindex> + %2728 = shape.assuming %2727 -> (tensor) { + %19688 = shape.broadcast %2725, %2726 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %2718, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %2724, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %2729 = shape.shape_of %2728 : tensor -> tensor<2xindex> + 
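// Same per-expert pattern as the previous layer, now inside the next decoder layer: its
+ // hidden states come from the attention ops above (rotary embedding, concatenation of the
+ // new key/value with the cached %arg73/%arg74, grouped-query broadcast from 8 to 32 heads,
+ // byteir.softmax attention) followed by RMSNorm and a fresh softmax + top-2 router.
+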
+    // ... (the generated StableHLO continues in the same pattern: for each remaining expert, a byteir.non_zero routing lookup, a gather of the routed token rows from the 1x3x4096 hidden states, the 4096->14336 and 14336->4096 expert matmuls with a logistic-gated activation, scaling by the gathered top-2 routing weights, and a scatter-add back into the 3x4096 accumulator; this is followed by the residual add and RMSNorm, the next decoder layer's rotary self-attention with KV-cache concatenation and byteir.softmax, the router matmul with byteir.softmax and byteir.top_k, and the start of that layer's expert loop) ...
tensor<4xi64> + %3251 = stablehlo.dynamic_reshape %3249, %from_elements_1120 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %3252 = stablehlo.concatenate %3245, %3248, %3251, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %3253 = "stablehlo.gather"(%3239, %3252) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %3254 = shape.shape_of %3253 : tensor<1x?x4096xf32> -> tensor<3xindex> + %3255 = shape.num_elements %3254 : tensor<3xindex> -> index + %3256 = stablehlo.compute_reshape_shape %3255, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %3257 = stablehlo.dynamic_reshape %3253, %3256 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %3258 = stablehlo.dot %3257, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %3259 = stablehlo.logistic %3258 : tensor + %3260 = shape.shape_of %3259 : tensor -> tensor<2xindex> + %3261 = shape.shape_of %3258 : tensor -> tensor<2xindex> + %3262 = shape.cstr_broadcastable %3260, %3261 : tensor<2xindex>, tensor<2xindex> + %3263 = shape.assuming %3262 -> (tensor) { + %19688 = shape.broadcast %3260, %3261 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %3259, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %3258, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %3264 = shape.shape_of %3263 : tensor -> tensor<2xindex> + %3265 = shape.cstr_broadcastable %3264, %3261 : tensor<2xindex>, tensor<2xindex> + %3266 = shape.assuming %3265 -> (tensor) { + %19688 = shape.broadcast %3264, %3261 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %3263, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %3258, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %3267 = stablehlo.dot %3266, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %3268 = stablehlo.reshape %3223 : (tensor<3x2xf32>) -> tensor<3x2x1xf32> + %dim_1121 = tensor.dim %3238, %c0 : tensor + %3269 = arith.index_cast %dim_1121 : index to i64 + %from_elements_1122 = tensor.from_elements %3269, %c1_i64 : tensor<2xi64> + %3270 = stablehlo.dynamic_reshape %3238, %from_elements_1122 : (tensor, tensor<2xi64>) -> tensor + %dim_1123 = tensor.dim %3235, %c0 : tensor + %3271 = arith.index_cast %dim_1123 : index to i64 + %from_elements_1124 = tensor.from_elements %3271, %c1_i64 : tensor<2xi64> + %3272 = stablehlo.dynamic_reshape %3235, %from_elements_1124 : (tensor, tensor<2xi64>) -> tensor + %3273 = stablehlo.concatenate %3270, %3272, dim = 1 : (tensor, tensor) -> tensor + %3274 = "stablehlo.gather"(%3268, %3273) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %3275 = shape.shape_of %3267 : tensor -> tensor<2xindex> + %3276 = shape.shape_of %3274 : tensor -> tensor<2xindex> + %3277 = shape.cstr_broadcastable %3275, %3276 : tensor<2xindex>, tensor<2xindex> + %3278 = shape.assuming %3277 -> (tensor) { + %19688 = shape.broadcast %3275, %3276 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = 
stablehlo.dynamic_broadcast_in_dim %3267, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %3274, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %3279 = shape.shape_of %3278 : tensor -> tensor<2xindex> + %3280 = stablehlo.dynamic_broadcast_in_dim %3278, %3279, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %3281 = stablehlo.dynamic_broadcast_in_dim %213, %3279, dims = [] : (tensor, tensor<2xindex>) -> tensor + %3282 = stablehlo.multiply %3280, %3281 : tensor + %dim_1125 = tensor.dim %3241, %c0 : tensor + %3283 = arith.index_cast %dim_1125 : index to i64 + %dim_1126 = tensor.dim %3278, %c0 : tensor + %3284 = arith.index_cast %dim_1126 : index to i64 + %3285 = arith.maxsi %3283, %3284 : i64 + %3286 = arith.index_cast %3285 : i64 to index + %from_elements_1127 = tensor.from_elements %3286, %c4096 : tensor<2xindex> + %3287 = stablehlo.dynamic_broadcast_in_dim %3241, %from_elements_1127, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_1128 = tensor.dim %3287, %c0 : tensor + %3288 = arith.index_cast %dim_1128 : index to i64 + %from_elements_1129 = tensor.from_elements %3288, %c4096_i64 : tensor<2xi64> + %3289 = stablehlo.real_dynamic_slice %3282, %c_22, %from_elements_1129, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_1130 = tensor.from_elements %3288, %c4096_i64, %c1_i64 : tensor<3xi64> + %3290 = stablehlo.dynamic_reshape %3287, %from_elements_1130 : (tensor, tensor<3xi64>) -> tensor + %3291 = stablehlo.dynamic_iota %from_elements_1130, dim = 1 : (tensor<3xi64>) -> tensor + %3292 = stablehlo.concatenate %3290, %3291, dim = 2 : (tensor, tensor) -> tensor + %3293 = "stablehlo.scatter"(%cst_2, %3292, %3289) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %3294 = stablehlo.slice %3228 [1:2, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %3295 = stablehlo.reshape %3294 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %3296 = stablehlo.custom_call @byteir.non_zero(%3295) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_1131 = tensor.dim %3296, %c0 : tensor + %3297 = arith.index_cast %dim_1131 : index to i64 + %from_elements_1132 = tensor.from_elements %3297, %c1_i64 : tensor<2xi64> + %3298 = stablehlo.real_dynamic_slice %3296, %c_22, %from_elements_1132, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1133 = tensor.dim %3298, %c0 : tensor + %3299 = arith.index_cast %dim_1133 : index to i64 + %from_elements_1134 = tensor.from_elements %3299 : tensor<1xi64> + %3300 = stablehlo.dynamic_reshape %3298, %from_elements_1134 : (tensor, tensor<1xi64>) -> tensor + %from_elements_1135 = tensor.from_elements %3297, %c2_i64 : tensor<2xi64> + %3301 = stablehlo.real_dynamic_slice %3296, %c_24, %from_elements_1135, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1136 = tensor.dim %3301, %c0 : tensor + %3302 = arith.index_cast %dim_1136 : index to i64 + %from_elements_1137 = tensor.from_elements %3302 : tensor<1xi64> + %3303 = stablehlo.dynamic_reshape %3301, %from_elements_1137 : (tensor, tensor<1xi64>) -> tensor + %dim_1138 = tensor.dim %3303, %c0 : tensor + %3304 = 
arith.index_cast %dim_1138 : index to i64 + %from_elements_1139 = tensor.from_elements %3304, %c1_i64 : tensor<2xi64> + %3305 = stablehlo.dynamic_reshape %3303, %from_elements_1139 : (tensor, tensor<2xi64>) -> tensor + %dim_1140 = tensor.dim %3305, %c0 : tensor + %3306 = arith.index_cast %dim_1140 : index to i64 + %from_elements_1141 = tensor.from_elements %c1_i64, %3306, %c4096_i64 : tensor<3xi64> + %3307 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_1141, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1142 = tensor.dim %3307, %c1 : tensor<1x?x4096xi64> + %3308 = arith.index_cast %dim_1142 : index to i64 + %from_elements_1143 = tensor.from_elements %c1_i64, %3308, %c4096_i64, %c1_i64 : tensor<4xi64> + %3309 = stablehlo.dynamic_reshape %3307, %from_elements_1143 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %3310 = stablehlo.dynamic_broadcast_in_dim %3305, %from_elements_1141, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1144 = tensor.dim %3310, %c1 : tensor<1x?x4096xi64> + %3311 = arith.index_cast %dim_1144 : index to i64 + %from_elements_1145 = tensor.from_elements %c1_i64, %3311, %c4096_i64, %c1_i64 : tensor<4xi64> + %3312 = stablehlo.dynamic_reshape %3310, %from_elements_1145 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %3313 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_1141, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1146 = tensor.dim %3313, %c1 : tensor<1x?x4096xi64> + %3314 = arith.index_cast %dim_1146 : index to i64 + %from_elements_1147 = tensor.from_elements %c1_i64, %3314, %c4096_i64, %c1_i64 : tensor<4xi64> + %3315 = stablehlo.dynamic_reshape %3313, %from_elements_1147 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %3316 = stablehlo.concatenate %3309, %3312, %3315, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %3317 = "stablehlo.gather"(%3239, %3316) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %3318 = shape.shape_of %3317 : tensor<1x?x4096xf32> -> tensor<3xindex> + %3319 = shape.num_elements %3318 : tensor<3xindex> -> index + %3320 = stablehlo.compute_reshape_shape %3319, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %3321 = stablehlo.dynamic_reshape %3317, %3320 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %3322 = stablehlo.dot %3321, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %3323 = stablehlo.logistic %3322 : tensor + %3324 = shape.shape_of %3323 : tensor -> tensor<2xindex> + %3325 = shape.shape_of %3322 : tensor -> tensor<2xindex> + %3326 = shape.cstr_broadcastable %3324, %3325 : tensor<2xindex>, tensor<2xindex> + %3327 = shape.assuming %3326 -> (tensor) { + %19688 = shape.broadcast %3324, %3325 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %3323, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %3322, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %3328 = shape.shape_of %3327 : tensor -> tensor<2xindex> + %3329 = shape.cstr_broadcastable %3328, %3325 : tensor<2xindex>, tensor<2xindex> + %3330 = shape.assuming %3329 -> (tensor) { + %19688 = shape.broadcast %3328, 
%3325 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %3327, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %3322, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %3331 = stablehlo.dot %3330, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_1148 = tensor.dim %3303, %c0 : tensor + %3332 = arith.index_cast %dim_1148 : index to i64 + %from_elements_1149 = tensor.from_elements %3332, %c1_i64 : tensor<2xi64> + %3333 = stablehlo.dynamic_reshape %3303, %from_elements_1149 : (tensor, tensor<2xi64>) -> tensor + %dim_1150 = tensor.dim %3300, %c0 : tensor + %3334 = arith.index_cast %dim_1150 : index to i64 + %from_elements_1151 = tensor.from_elements %3334, %c1_i64 : tensor<2xi64> + %3335 = stablehlo.dynamic_reshape %3300, %from_elements_1151 : (tensor, tensor<2xi64>) -> tensor + %3336 = stablehlo.concatenate %3333, %3335, dim = 1 : (tensor, tensor) -> tensor + %3337 = "stablehlo.gather"(%3268, %3336) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %3338 = shape.shape_of %3331 : tensor -> tensor<2xindex> + %3339 = shape.shape_of %3337 : tensor -> tensor<2xindex> + %3340 = shape.cstr_broadcastable %3338, %3339 : tensor<2xindex>, tensor<2xindex> + %3341 = shape.assuming %3340 -> (tensor) { + %19688 = shape.broadcast %3338, %3339 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %3331, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %3337, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %3342 = shape.shape_of %3341 : tensor -> tensor<2xindex> + %3343 = stablehlo.dynamic_broadcast_in_dim %3341, %3342, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %3344 = stablehlo.dynamic_broadcast_in_dim %213, %3342, dims = [] : (tensor, tensor<2xindex>) -> tensor + %3345 = stablehlo.multiply %3343, %3344 : tensor + %dim_1152 = tensor.dim %3305, %c0 : tensor + %3346 = arith.index_cast %dim_1152 : index to i64 + %dim_1153 = tensor.dim %3341, %c0 : tensor + %3347 = arith.index_cast %dim_1153 : index to i64 + %3348 = arith.maxsi %3346, %3347 : i64 + %3349 = arith.index_cast %3348 : i64 to index + %from_elements_1154 = tensor.from_elements %3349, %c4096 : tensor<2xindex> + %3350 = stablehlo.dynamic_broadcast_in_dim %3305, %from_elements_1154, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_1155 = tensor.dim %3350, %c0 : tensor + %3351 = arith.index_cast %dim_1155 : index to i64 + %from_elements_1156 = tensor.from_elements %3351, %c4096_i64 : tensor<2xi64> + %3352 = stablehlo.real_dynamic_slice %3345, %c_22, %from_elements_1156, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_1157 = tensor.from_elements %3351, %c4096_i64, %c1_i64 : tensor<3xi64> + %3353 = stablehlo.dynamic_reshape %3350, %from_elements_1157 : (tensor, tensor<3xi64>) -> tensor + %3354 = stablehlo.dynamic_iota %from_elements_1157, dim = 1 : (tensor<3xi64>) -> tensor + %3355 = stablehlo.concatenate %3353, %3354, dim = 2 : (tensor, tensor) -> tensor + %3356 = "stablehlo.scatter"(%3293, %3355, %3352) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, 
unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %3357 = stablehlo.slice %3228 [2:3, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %3358 = stablehlo.reshape %3357 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %3359 = stablehlo.custom_call @byteir.non_zero(%3358) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_1158 = tensor.dim %3359, %c0 : tensor + %3360 = arith.index_cast %dim_1158 : index to i64 + %from_elements_1159 = tensor.from_elements %3360, %c1_i64 : tensor<2xi64> + %3361 = stablehlo.real_dynamic_slice %3359, %c_22, %from_elements_1159, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1160 = tensor.dim %3361, %c0 : tensor + %3362 = arith.index_cast %dim_1160 : index to i64 + %from_elements_1161 = tensor.from_elements %3362 : tensor<1xi64> + %3363 = stablehlo.dynamic_reshape %3361, %from_elements_1161 : (tensor, tensor<1xi64>) -> tensor + %from_elements_1162 = tensor.from_elements %3360, %c2_i64 : tensor<2xi64> + %3364 = stablehlo.real_dynamic_slice %3359, %c_24, %from_elements_1162, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1163 = tensor.dim %3364, %c0 : tensor + %3365 = arith.index_cast %dim_1163 : index to i64 + %from_elements_1164 = tensor.from_elements %3365 : tensor<1xi64> + %3366 = stablehlo.dynamic_reshape %3364, %from_elements_1164 : (tensor, tensor<1xi64>) -> tensor + %dim_1165 = tensor.dim %3366, %c0 : tensor + %3367 = arith.index_cast %dim_1165 : index to i64 + %from_elements_1166 = tensor.from_elements %3367, %c1_i64 : tensor<2xi64> + %3368 = stablehlo.dynamic_reshape %3366, %from_elements_1166 : (tensor, tensor<2xi64>) -> tensor + %dim_1167 = tensor.dim %3368, %c0 : tensor + %3369 = arith.index_cast %dim_1167 : index to i64 + %from_elements_1168 = tensor.from_elements %c1_i64, %3369, %c4096_i64 : tensor<3xi64> + %3370 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_1168, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1169 = tensor.dim %3370, %c1 : tensor<1x?x4096xi64> + %3371 = arith.index_cast %dim_1169 : index to i64 + %from_elements_1170 = tensor.from_elements %c1_i64, %3371, %c4096_i64, %c1_i64 : tensor<4xi64> + %3372 = stablehlo.dynamic_reshape %3370, %from_elements_1170 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %3373 = stablehlo.dynamic_broadcast_in_dim %3368, %from_elements_1168, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1171 = tensor.dim %3373, %c1 : tensor<1x?x4096xi64> + %3374 = arith.index_cast %dim_1171 : index to i64 + %from_elements_1172 = tensor.from_elements %c1_i64, %3374, %c4096_i64, %c1_i64 : tensor<4xi64> + %3375 = stablehlo.dynamic_reshape %3373, %from_elements_1172 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %3376 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_1168, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1173 = tensor.dim %3376, %c1 : tensor<1x?x4096xi64> + %3377 = arith.index_cast %dim_1173 : index to i64 + %from_elements_1174 = tensor.from_elements %c1_i64, %3377, %c4096_i64, %c1_i64 : tensor<4xi64> + %3378 = stablehlo.dynamic_reshape %3376, %from_elements_1174 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %3379 = stablehlo.concatenate %3372, %3375, %3378, dim = 3 : (tensor<1x?x4096x1xi64>, 
tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %3380 = "stablehlo.gather"(%3239, %3379) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %3381 = shape.shape_of %3380 : tensor<1x?x4096xf32> -> tensor<3xindex> + %3382 = shape.num_elements %3381 : tensor<3xindex> -> index + %3383 = stablehlo.compute_reshape_shape %3382, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %3384 = stablehlo.dynamic_reshape %3380, %3383 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %3385 = stablehlo.dot %3384, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %3386 = stablehlo.logistic %3385 : tensor + %3387 = shape.shape_of %3386 : tensor -> tensor<2xindex> + %3388 = shape.shape_of %3385 : tensor -> tensor<2xindex> + %3389 = shape.cstr_broadcastable %3387, %3388 : tensor<2xindex>, tensor<2xindex> + %3390 = shape.assuming %3389 -> (tensor) { + %19688 = shape.broadcast %3387, %3388 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %3386, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %3385, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %3391 = shape.shape_of %3390 : tensor -> tensor<2xindex> + %3392 = shape.cstr_broadcastable %3391, %3388 : tensor<2xindex>, tensor<2xindex> + %3393 = shape.assuming %3392 -> (tensor) { + %19688 = shape.broadcast %3391, %3388 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %3390, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %3385, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %3394 = stablehlo.dot %3393, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_1175 = tensor.dim %3366, %c0 : tensor + %3395 = arith.index_cast %dim_1175 : index to i64 + %from_elements_1176 = tensor.from_elements %3395, %c1_i64 : tensor<2xi64> + %3396 = stablehlo.dynamic_reshape %3366, %from_elements_1176 : (tensor, tensor<2xi64>) -> tensor + %dim_1177 = tensor.dim %3363, %c0 : tensor + %3397 = arith.index_cast %dim_1177 : index to i64 + %from_elements_1178 = tensor.from_elements %3397, %c1_i64 : tensor<2xi64> + %3398 = stablehlo.dynamic_reshape %3363, %from_elements_1178 : (tensor, tensor<2xi64>) -> tensor + %3399 = stablehlo.concatenate %3396, %3398, dim = 1 : (tensor, tensor) -> tensor + %3400 = "stablehlo.gather"(%3268, %3399) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %3401 = shape.shape_of %3394 : tensor -> tensor<2xindex> + %3402 = shape.shape_of %3400 : tensor -> tensor<2xindex> + %3403 = shape.cstr_broadcastable %3401, %3402 : tensor<2xindex>, tensor<2xindex> + %3404 = shape.assuming %3403 -> (tensor) { + %19688 = shape.broadcast %3401, %3402 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %3394, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %3400, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + 
%3405 = shape.shape_of %3404 : tensor -> tensor<2xindex> + %3406 = stablehlo.dynamic_broadcast_in_dim %3404, %3405, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %3407 = stablehlo.dynamic_broadcast_in_dim %213, %3405, dims = [] : (tensor, tensor<2xindex>) -> tensor + %3408 = stablehlo.multiply %3406, %3407 : tensor + %dim_1179 = tensor.dim %3368, %c0 : tensor + %3409 = arith.index_cast %dim_1179 : index to i64 + %dim_1180 = tensor.dim %3404, %c0 : tensor + %3410 = arith.index_cast %dim_1180 : index to i64 + %3411 = arith.maxsi %3409, %3410 : i64 + %3412 = arith.index_cast %3411 : i64 to index + %from_elements_1181 = tensor.from_elements %3412, %c4096 : tensor<2xindex> + %3413 = stablehlo.dynamic_broadcast_in_dim %3368, %from_elements_1181, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_1182 = tensor.dim %3413, %c0 : tensor + %3414 = arith.index_cast %dim_1182 : index to i64 + %from_elements_1183 = tensor.from_elements %3414, %c4096_i64 : tensor<2xi64> + %3415 = stablehlo.real_dynamic_slice %3408, %c_22, %from_elements_1183, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_1184 = tensor.from_elements %3414, %c4096_i64, %c1_i64 : tensor<3xi64> + %3416 = stablehlo.dynamic_reshape %3413, %from_elements_1184 : (tensor, tensor<3xi64>) -> tensor + %3417 = stablehlo.dynamic_iota %from_elements_1184, dim = 1 : (tensor<3xi64>) -> tensor + %3418 = stablehlo.concatenate %3416, %3417, dim = 2 : (tensor, tensor) -> tensor + %3419 = "stablehlo.scatter"(%3356, %3418, %3415) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %3420 = stablehlo.slice %3228 [3:4, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %3421 = stablehlo.reshape %3420 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %3422 = stablehlo.custom_call @byteir.non_zero(%3421) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_1185 = tensor.dim %3422, %c0 : tensor + %3423 = arith.index_cast %dim_1185 : index to i64 + %from_elements_1186 = tensor.from_elements %3423, %c1_i64 : tensor<2xi64> + %3424 = stablehlo.real_dynamic_slice %3422, %c_22, %from_elements_1186, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1187 = tensor.dim %3424, %c0 : tensor + %3425 = arith.index_cast %dim_1187 : index to i64 + %from_elements_1188 = tensor.from_elements %3425 : tensor<1xi64> + %3426 = stablehlo.dynamic_reshape %3424, %from_elements_1188 : (tensor, tensor<1xi64>) -> tensor + %from_elements_1189 = tensor.from_elements %3423, %c2_i64 : tensor<2xi64> + %3427 = stablehlo.real_dynamic_slice %3422, %c_24, %from_elements_1189, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1190 = tensor.dim %3427, %c0 : tensor + %3428 = arith.index_cast %dim_1190 : index to i64 + %from_elements_1191 = tensor.from_elements %3428 : tensor<1xi64> + %3429 = stablehlo.dynamic_reshape %3427, %from_elements_1191 : (tensor, tensor<1xi64>) -> tensor + %dim_1192 = tensor.dim %3429, %c0 : tensor + %3430 = arith.index_cast %dim_1192 : index to i64 + %from_elements_1193 = tensor.from_elements %3430, %c1_i64 : tensor<2xi64> + %3431 = stablehlo.dynamic_reshape %3429, %from_elements_1193 : (tensor, tensor<2xi64>) -> tensor + %dim_1194 = tensor.dim %3431, %c0 : tensor + %3432 = arith.index_cast %dim_1194 : index to i64 + 
%from_elements_1195 = tensor.from_elements %c1_i64, %3432, %c4096_i64 : tensor<3xi64> + %3433 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_1195, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1196 = tensor.dim %3433, %c1 : tensor<1x?x4096xi64> + %3434 = arith.index_cast %dim_1196 : index to i64 + %from_elements_1197 = tensor.from_elements %c1_i64, %3434, %c4096_i64, %c1_i64 : tensor<4xi64> + %3435 = stablehlo.dynamic_reshape %3433, %from_elements_1197 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %3436 = stablehlo.dynamic_broadcast_in_dim %3431, %from_elements_1195, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1198 = tensor.dim %3436, %c1 : tensor<1x?x4096xi64> + %3437 = arith.index_cast %dim_1198 : index to i64 + %from_elements_1199 = tensor.from_elements %c1_i64, %3437, %c4096_i64, %c1_i64 : tensor<4xi64> + %3438 = stablehlo.dynamic_reshape %3436, %from_elements_1199 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %3439 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_1195, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1200 = tensor.dim %3439, %c1 : tensor<1x?x4096xi64> + %3440 = arith.index_cast %dim_1200 : index to i64 + %from_elements_1201 = tensor.from_elements %c1_i64, %3440, %c4096_i64, %c1_i64 : tensor<4xi64> + %3441 = stablehlo.dynamic_reshape %3439, %from_elements_1201 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %3442 = stablehlo.concatenate %3435, %3438, %3441, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %3443 = "stablehlo.gather"(%3239, %3442) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %3444 = shape.shape_of %3443 : tensor<1x?x4096xf32> -> tensor<3xindex> + %3445 = shape.num_elements %3444 : tensor<3xindex> -> index + %3446 = stablehlo.compute_reshape_shape %3445, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %3447 = stablehlo.dynamic_reshape %3443, %3446 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %3448 = stablehlo.dot %3447, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %3449 = stablehlo.logistic %3448 : tensor + %3450 = shape.shape_of %3449 : tensor -> tensor<2xindex> + %3451 = shape.shape_of %3448 : tensor -> tensor<2xindex> + %3452 = shape.cstr_broadcastable %3450, %3451 : tensor<2xindex>, tensor<2xindex> + %3453 = shape.assuming %3452 -> (tensor) { + %19688 = shape.broadcast %3450, %3451 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %3449, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %3448, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %3454 = shape.shape_of %3453 : tensor -> tensor<2xindex> + %3455 = shape.cstr_broadcastable %3454, %3451 : tensor<2xindex>, tensor<2xindex> + %3456 = shape.assuming %3455 -> (tensor) { + %19688 = shape.broadcast %3454, %3451 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %3453, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %3448, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply 
%19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %3457 = stablehlo.dot %3456, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_1202 = tensor.dim %3429, %c0 : tensor + %3458 = arith.index_cast %dim_1202 : index to i64 + %from_elements_1203 = tensor.from_elements %3458, %c1_i64 : tensor<2xi64> + %3459 = stablehlo.dynamic_reshape %3429, %from_elements_1203 : (tensor, tensor<2xi64>) -> tensor + %dim_1204 = tensor.dim %3426, %c0 : tensor + %3460 = arith.index_cast %dim_1204 : index to i64 + %from_elements_1205 = tensor.from_elements %3460, %c1_i64 : tensor<2xi64> + %3461 = stablehlo.dynamic_reshape %3426, %from_elements_1205 : (tensor, tensor<2xi64>) -> tensor + %3462 = stablehlo.concatenate %3459, %3461, dim = 1 : (tensor, tensor) -> tensor + %3463 = "stablehlo.gather"(%3268, %3462) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %3464 = shape.shape_of %3457 : tensor -> tensor<2xindex> + %3465 = shape.shape_of %3463 : tensor -> tensor<2xindex> + %3466 = shape.cstr_broadcastable %3464, %3465 : tensor<2xindex>, tensor<2xindex> + %3467 = shape.assuming %3466 -> (tensor) { + %19688 = shape.broadcast %3464, %3465 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %3457, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %3463, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %3468 = shape.shape_of %3467 : tensor -> tensor<2xindex> + %3469 = stablehlo.dynamic_broadcast_in_dim %3467, %3468, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %3470 = stablehlo.dynamic_broadcast_in_dim %213, %3468, dims = [] : (tensor, tensor<2xindex>) -> tensor + %3471 = stablehlo.multiply %3469, %3470 : tensor + %dim_1206 = tensor.dim %3431, %c0 : tensor + %3472 = arith.index_cast %dim_1206 : index to i64 + %dim_1207 = tensor.dim %3467, %c0 : tensor + %3473 = arith.index_cast %dim_1207 : index to i64 + %3474 = arith.maxsi %3472, %3473 : i64 + %3475 = arith.index_cast %3474 : i64 to index + %from_elements_1208 = tensor.from_elements %3475, %c4096 : tensor<2xindex> + %3476 = stablehlo.dynamic_broadcast_in_dim %3431, %from_elements_1208, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_1209 = tensor.dim %3476, %c0 : tensor + %3477 = arith.index_cast %dim_1209 : index to i64 + %from_elements_1210 = tensor.from_elements %3477, %c4096_i64 : tensor<2xi64> + %3478 = stablehlo.real_dynamic_slice %3471, %c_22, %from_elements_1210, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_1211 = tensor.from_elements %3477, %c4096_i64, %c1_i64 : tensor<3xi64> + %3479 = stablehlo.dynamic_reshape %3476, %from_elements_1211 : (tensor, tensor<3xi64>) -> tensor + %3480 = stablehlo.dynamic_iota %from_elements_1211, dim = 1 : (tensor<3xi64>) -> tensor + %3481 = stablehlo.concatenate %3479, %3480, dim = 2 : (tensor, tensor) -> tensor + %3482 = "stablehlo.scatter"(%3419, %3481, %3478) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %3483 = stablehlo.slice %3228 [4:5, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + 
%3484 = stablehlo.reshape %3483 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %3485 = stablehlo.custom_call @byteir.non_zero(%3484) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_1212 = tensor.dim %3485, %c0 : tensor + %3486 = arith.index_cast %dim_1212 : index to i64 + %from_elements_1213 = tensor.from_elements %3486, %c1_i64 : tensor<2xi64> + %3487 = stablehlo.real_dynamic_slice %3485, %c_22, %from_elements_1213, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1214 = tensor.dim %3487, %c0 : tensor + %3488 = arith.index_cast %dim_1214 : index to i64 + %from_elements_1215 = tensor.from_elements %3488 : tensor<1xi64> + %3489 = stablehlo.dynamic_reshape %3487, %from_elements_1215 : (tensor, tensor<1xi64>) -> tensor + %from_elements_1216 = tensor.from_elements %3486, %c2_i64 : tensor<2xi64> + %3490 = stablehlo.real_dynamic_slice %3485, %c_24, %from_elements_1216, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1217 = tensor.dim %3490, %c0 : tensor + %3491 = arith.index_cast %dim_1217 : index to i64 + %from_elements_1218 = tensor.from_elements %3491 : tensor<1xi64> + %3492 = stablehlo.dynamic_reshape %3490, %from_elements_1218 : (tensor, tensor<1xi64>) -> tensor + %dim_1219 = tensor.dim %3492, %c0 : tensor + %3493 = arith.index_cast %dim_1219 : index to i64 + %from_elements_1220 = tensor.from_elements %3493, %c1_i64 : tensor<2xi64> + %3494 = stablehlo.dynamic_reshape %3492, %from_elements_1220 : (tensor, tensor<2xi64>) -> tensor + %dim_1221 = tensor.dim %3494, %c0 : tensor + %3495 = arith.index_cast %dim_1221 : index to i64 + %from_elements_1222 = tensor.from_elements %c1_i64, %3495, %c4096_i64 : tensor<3xi64> + %3496 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_1222, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1223 = tensor.dim %3496, %c1 : tensor<1x?x4096xi64> + %3497 = arith.index_cast %dim_1223 : index to i64 + %from_elements_1224 = tensor.from_elements %c1_i64, %3497, %c4096_i64, %c1_i64 : tensor<4xi64> + %3498 = stablehlo.dynamic_reshape %3496, %from_elements_1224 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %3499 = stablehlo.dynamic_broadcast_in_dim %3494, %from_elements_1222, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1225 = tensor.dim %3499, %c1 : tensor<1x?x4096xi64> + %3500 = arith.index_cast %dim_1225 : index to i64 + %from_elements_1226 = tensor.from_elements %c1_i64, %3500, %c4096_i64, %c1_i64 : tensor<4xi64> + %3501 = stablehlo.dynamic_reshape %3499, %from_elements_1226 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %3502 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_1222, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1227 = tensor.dim %3502, %c1 : tensor<1x?x4096xi64> + %3503 = arith.index_cast %dim_1227 : index to i64 + %from_elements_1228 = tensor.from_elements %c1_i64, %3503, %c4096_i64, %c1_i64 : tensor<4xi64> + %3504 = stablehlo.dynamic_reshape %3502, %from_elements_1228 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %3505 = stablehlo.concatenate %3498, %3501, %3504, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %3506 = "stablehlo.gather"(%3239, %3505) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %3507 = shape.shape_of %3506 : 
tensor<1x?x4096xf32> -> tensor<3xindex> + %3508 = shape.num_elements %3507 : tensor<3xindex> -> index + %3509 = stablehlo.compute_reshape_shape %3508, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %3510 = stablehlo.dynamic_reshape %3506, %3509 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %3511 = stablehlo.dot %3510, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %3512 = stablehlo.logistic %3511 : tensor + %3513 = shape.shape_of %3512 : tensor -> tensor<2xindex> + %3514 = shape.shape_of %3511 : tensor -> tensor<2xindex> + %3515 = shape.cstr_broadcastable %3513, %3514 : tensor<2xindex>, tensor<2xindex> + %3516 = shape.assuming %3515 -> (tensor) { + %19688 = shape.broadcast %3513, %3514 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %3512, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %3511, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %3517 = shape.shape_of %3516 : tensor -> tensor<2xindex> + %3518 = shape.cstr_broadcastable %3517, %3514 : tensor<2xindex>, tensor<2xindex> + %3519 = shape.assuming %3518 -> (tensor) { + %19688 = shape.broadcast %3517, %3514 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %3516, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %3511, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %3520 = stablehlo.dot %3519, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_1229 = tensor.dim %3492, %c0 : tensor + %3521 = arith.index_cast %dim_1229 : index to i64 + %from_elements_1230 = tensor.from_elements %3521, %c1_i64 : tensor<2xi64> + %3522 = stablehlo.dynamic_reshape %3492, %from_elements_1230 : (tensor, tensor<2xi64>) -> tensor + %dim_1231 = tensor.dim %3489, %c0 : tensor + %3523 = arith.index_cast %dim_1231 : index to i64 + %from_elements_1232 = tensor.from_elements %3523, %c1_i64 : tensor<2xi64> + %3524 = stablehlo.dynamic_reshape %3489, %from_elements_1232 : (tensor, tensor<2xi64>) -> tensor + %3525 = stablehlo.concatenate %3522, %3524, dim = 1 : (tensor, tensor) -> tensor + %3526 = "stablehlo.gather"(%3268, %3525) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %3527 = shape.shape_of %3520 : tensor -> tensor<2xindex> + %3528 = shape.shape_of %3526 : tensor -> tensor<2xindex> + %3529 = shape.cstr_broadcastable %3527, %3528 : tensor<2xindex>, tensor<2xindex> + %3530 = shape.assuming %3529 -> (tensor) { + %19688 = shape.broadcast %3527, %3528 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %3520, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %3526, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %3531 = shape.shape_of %3530 : tensor -> tensor<2xindex> + %3532 = stablehlo.dynamic_broadcast_in_dim %3530, %3531, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %3533 = stablehlo.dynamic_broadcast_in_dim %213, %3531, dims = [] : (tensor, tensor<2xindex>) -> tensor + %3534 = stablehlo.multiply %3532, %3533 : 
tensor + %dim_1233 = tensor.dim %3494, %c0 : tensor + %3535 = arith.index_cast %dim_1233 : index to i64 + %dim_1234 = tensor.dim %3530, %c0 : tensor + %3536 = arith.index_cast %dim_1234 : index to i64 + %3537 = arith.maxsi %3535, %3536 : i64 + %3538 = arith.index_cast %3537 : i64 to index + %from_elements_1235 = tensor.from_elements %3538, %c4096 : tensor<2xindex> + %3539 = stablehlo.dynamic_broadcast_in_dim %3494, %from_elements_1235, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_1236 = tensor.dim %3539, %c0 : tensor + %3540 = arith.index_cast %dim_1236 : index to i64 + %from_elements_1237 = tensor.from_elements %3540, %c4096_i64 : tensor<2xi64> + %3541 = stablehlo.real_dynamic_slice %3534, %c_22, %from_elements_1237, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_1238 = tensor.from_elements %3540, %c4096_i64, %c1_i64 : tensor<3xi64> + %3542 = stablehlo.dynamic_reshape %3539, %from_elements_1238 : (tensor, tensor<3xi64>) -> tensor + %3543 = stablehlo.dynamic_iota %from_elements_1238, dim = 1 : (tensor<3xi64>) -> tensor + %3544 = stablehlo.concatenate %3542, %3543, dim = 2 : (tensor, tensor) -> tensor + %3545 = "stablehlo.scatter"(%3482, %3544, %3541) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %3546 = stablehlo.slice %3228 [5:6, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %3547 = stablehlo.reshape %3546 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %3548 = stablehlo.custom_call @byteir.non_zero(%3547) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_1239 = tensor.dim %3548, %c0 : tensor + %3549 = arith.index_cast %dim_1239 : index to i64 + %from_elements_1240 = tensor.from_elements %3549, %c1_i64 : tensor<2xi64> + %3550 = stablehlo.real_dynamic_slice %3548, %c_22, %from_elements_1240, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1241 = tensor.dim %3550, %c0 : tensor + %3551 = arith.index_cast %dim_1241 : index to i64 + %from_elements_1242 = tensor.from_elements %3551 : tensor<1xi64> + %3552 = stablehlo.dynamic_reshape %3550, %from_elements_1242 : (tensor, tensor<1xi64>) -> tensor + %from_elements_1243 = tensor.from_elements %3549, %c2_i64 : tensor<2xi64> + %3553 = stablehlo.real_dynamic_slice %3548, %c_24, %from_elements_1243, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1244 = tensor.dim %3553, %c0 : tensor + %3554 = arith.index_cast %dim_1244 : index to i64 + %from_elements_1245 = tensor.from_elements %3554 : tensor<1xi64> + %3555 = stablehlo.dynamic_reshape %3553, %from_elements_1245 : (tensor, tensor<1xi64>) -> tensor + %dim_1246 = tensor.dim %3555, %c0 : tensor + %3556 = arith.index_cast %dim_1246 : index to i64 + %from_elements_1247 = tensor.from_elements %3556, %c1_i64 : tensor<2xi64> + %3557 = stablehlo.dynamic_reshape %3555, %from_elements_1247 : (tensor, tensor<2xi64>) -> tensor + %dim_1248 = tensor.dim %3557, %c0 : tensor + %3558 = arith.index_cast %dim_1248 : index to i64 + %from_elements_1249 = tensor.from_elements %c1_i64, %3558, %c4096_i64 : tensor<3xi64> + %3559 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_1249, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1250 = tensor.dim %3559, %c1 : tensor<1x?x4096xi64> + %3560 = 
arith.index_cast %dim_1250 : index to i64 + %from_elements_1251 = tensor.from_elements %c1_i64, %3560, %c4096_i64, %c1_i64 : tensor<4xi64> + %3561 = stablehlo.dynamic_reshape %3559, %from_elements_1251 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %3562 = stablehlo.dynamic_broadcast_in_dim %3557, %from_elements_1249, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1252 = tensor.dim %3562, %c1 : tensor<1x?x4096xi64> + %3563 = arith.index_cast %dim_1252 : index to i64 + %from_elements_1253 = tensor.from_elements %c1_i64, %3563, %c4096_i64, %c1_i64 : tensor<4xi64> + %3564 = stablehlo.dynamic_reshape %3562, %from_elements_1253 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %3565 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_1249, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1254 = tensor.dim %3565, %c1 : tensor<1x?x4096xi64> + %3566 = arith.index_cast %dim_1254 : index to i64 + %from_elements_1255 = tensor.from_elements %c1_i64, %3566, %c4096_i64, %c1_i64 : tensor<4xi64> + %3567 = stablehlo.dynamic_reshape %3565, %from_elements_1255 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %3568 = stablehlo.concatenate %3561, %3564, %3567, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %3569 = "stablehlo.gather"(%3239, %3568) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %3570 = shape.shape_of %3569 : tensor<1x?x4096xf32> -> tensor<3xindex> + %3571 = shape.num_elements %3570 : tensor<3xindex> -> index + %3572 = stablehlo.compute_reshape_shape %3571, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %3573 = stablehlo.dynamic_reshape %3569, %3572 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %3574 = stablehlo.dot %3573, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %3575 = stablehlo.logistic %3574 : tensor + %3576 = shape.shape_of %3575 : tensor -> tensor<2xindex> + %3577 = shape.shape_of %3574 : tensor -> tensor<2xindex> + %3578 = shape.cstr_broadcastable %3576, %3577 : tensor<2xindex>, tensor<2xindex> + %3579 = shape.assuming %3578 -> (tensor) { + %19688 = shape.broadcast %3576, %3577 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %3575, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %3574, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %3580 = shape.shape_of %3579 : tensor -> tensor<2xindex> + %3581 = shape.cstr_broadcastable %3580, %3577 : tensor<2xindex>, tensor<2xindex> + %3582 = shape.assuming %3581 -> (tensor) { + %19688 = shape.broadcast %3580, %3577 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %3579, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %3574, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %3583 = stablehlo.dot %3582, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_1256 = tensor.dim %3555, %c0 : tensor + %3584 = arith.index_cast %dim_1256 : index to i64 + %from_elements_1257 = tensor.from_elements %3584, %c1_i64 : 
tensor<2xi64> + %3585 = stablehlo.dynamic_reshape %3555, %from_elements_1257 : (tensor, tensor<2xi64>) -> tensor + %dim_1258 = tensor.dim %3552, %c0 : tensor + %3586 = arith.index_cast %dim_1258 : index to i64 + %from_elements_1259 = tensor.from_elements %3586, %c1_i64 : tensor<2xi64> + %3587 = stablehlo.dynamic_reshape %3552, %from_elements_1259 : (tensor, tensor<2xi64>) -> tensor + %3588 = stablehlo.concatenate %3585, %3587, dim = 1 : (tensor, tensor) -> tensor + %3589 = "stablehlo.gather"(%3268, %3588) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %3590 = shape.shape_of %3583 : tensor -> tensor<2xindex> + %3591 = shape.shape_of %3589 : tensor -> tensor<2xindex> + %3592 = shape.cstr_broadcastable %3590, %3591 : tensor<2xindex>, tensor<2xindex> + %3593 = shape.assuming %3592 -> (tensor) { + %19688 = shape.broadcast %3590, %3591 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %3583, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %3589, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %3594 = shape.shape_of %3593 : tensor -> tensor<2xindex> + %3595 = stablehlo.dynamic_broadcast_in_dim %3593, %3594, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %3596 = stablehlo.dynamic_broadcast_in_dim %213, %3594, dims = [] : (tensor, tensor<2xindex>) -> tensor + %3597 = stablehlo.multiply %3595, %3596 : tensor + %dim_1260 = tensor.dim %3557, %c0 : tensor + %3598 = arith.index_cast %dim_1260 : index to i64 + %dim_1261 = tensor.dim %3593, %c0 : tensor + %3599 = arith.index_cast %dim_1261 : index to i64 + %3600 = arith.maxsi %3598, %3599 : i64 + %3601 = arith.index_cast %3600 : i64 to index + %from_elements_1262 = tensor.from_elements %3601, %c4096 : tensor<2xindex> + %3602 = stablehlo.dynamic_broadcast_in_dim %3557, %from_elements_1262, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_1263 = tensor.dim %3602, %c0 : tensor + %3603 = arith.index_cast %dim_1263 : index to i64 + %from_elements_1264 = tensor.from_elements %3603, %c4096_i64 : tensor<2xi64> + %3604 = stablehlo.real_dynamic_slice %3597, %c_22, %from_elements_1264, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_1265 = tensor.from_elements %3603, %c4096_i64, %c1_i64 : tensor<3xi64> + %3605 = stablehlo.dynamic_reshape %3602, %from_elements_1265 : (tensor, tensor<3xi64>) -> tensor + %3606 = stablehlo.dynamic_iota %from_elements_1265, dim = 1 : (tensor<3xi64>) -> tensor + %3607 = stablehlo.concatenate %3605, %3606, dim = 2 : (tensor, tensor) -> tensor + %3608 = "stablehlo.scatter"(%3545, %3607, %3604) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %3609 = stablehlo.slice %3228 [6:7, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %3610 = stablehlo.reshape %3609 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %3611 = stablehlo.custom_call @byteir.non_zero(%3610) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_1266 = tensor.dim %3611, %c0 : tensor + %3612 = arith.index_cast %dim_1266 : index to i64 + %from_elements_1267 = 
tensor.from_elements %3612, %c1_i64 : tensor<2xi64> + %3613 = stablehlo.real_dynamic_slice %3611, %c_22, %from_elements_1267, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1268 = tensor.dim %3613, %c0 : tensor + %3614 = arith.index_cast %dim_1268 : index to i64 + %from_elements_1269 = tensor.from_elements %3614 : tensor<1xi64> + %3615 = stablehlo.dynamic_reshape %3613, %from_elements_1269 : (tensor, tensor<1xi64>) -> tensor + %from_elements_1270 = tensor.from_elements %3612, %c2_i64 : tensor<2xi64> + %3616 = stablehlo.real_dynamic_slice %3611, %c_24, %from_elements_1270, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1271 = tensor.dim %3616, %c0 : tensor + %3617 = arith.index_cast %dim_1271 : index to i64 + %from_elements_1272 = tensor.from_elements %3617 : tensor<1xi64> + %3618 = stablehlo.dynamic_reshape %3616, %from_elements_1272 : (tensor, tensor<1xi64>) -> tensor + %dim_1273 = tensor.dim %3618, %c0 : tensor + %3619 = arith.index_cast %dim_1273 : index to i64 + %from_elements_1274 = tensor.from_elements %3619, %c1_i64 : tensor<2xi64> + %3620 = stablehlo.dynamic_reshape %3618, %from_elements_1274 : (tensor, tensor<2xi64>) -> tensor + %dim_1275 = tensor.dim %3620, %c0 : tensor + %3621 = arith.index_cast %dim_1275 : index to i64 + %from_elements_1276 = tensor.from_elements %c1_i64, %3621, %c4096_i64 : tensor<3xi64> + %3622 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_1276, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1277 = tensor.dim %3622, %c1 : tensor<1x?x4096xi64> + %3623 = arith.index_cast %dim_1277 : index to i64 + %from_elements_1278 = tensor.from_elements %c1_i64, %3623, %c4096_i64, %c1_i64 : tensor<4xi64> + %3624 = stablehlo.dynamic_reshape %3622, %from_elements_1278 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %3625 = stablehlo.dynamic_broadcast_in_dim %3620, %from_elements_1276, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1279 = tensor.dim %3625, %c1 : tensor<1x?x4096xi64> + %3626 = arith.index_cast %dim_1279 : index to i64 + %from_elements_1280 = tensor.from_elements %c1_i64, %3626, %c4096_i64, %c1_i64 : tensor<4xi64> + %3627 = stablehlo.dynamic_reshape %3625, %from_elements_1280 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %3628 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_1276, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1281 = tensor.dim %3628, %c1 : tensor<1x?x4096xi64> + %3629 = arith.index_cast %dim_1281 : index to i64 + %from_elements_1282 = tensor.from_elements %c1_i64, %3629, %c4096_i64, %c1_i64 : tensor<4xi64> + %3630 = stablehlo.dynamic_reshape %3628, %from_elements_1282 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %3631 = stablehlo.concatenate %3624, %3627, %3630, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %3632 = "stablehlo.gather"(%3239, %3631) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %3633 = shape.shape_of %3632 : tensor<1x?x4096xf32> -> tensor<3xindex> + %3634 = shape.num_elements %3633 : tensor<3xindex> -> index + %3635 = stablehlo.compute_reshape_shape %3634, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %3636 = stablehlo.dynamic_reshape %3632, %3635 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %3637 = 
stablehlo.dot %3636, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %3638 = stablehlo.logistic %3637 : tensor + %3639 = shape.shape_of %3638 : tensor -> tensor<2xindex> + %3640 = shape.shape_of %3637 : tensor -> tensor<2xindex> + %3641 = shape.cstr_broadcastable %3639, %3640 : tensor<2xindex>, tensor<2xindex> + %3642 = shape.assuming %3641 -> (tensor) { + %19688 = shape.broadcast %3639, %3640 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %3638, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %3637, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %3643 = shape.shape_of %3642 : tensor -> tensor<2xindex> + %3644 = shape.cstr_broadcastable %3643, %3640 : tensor<2xindex>, tensor<2xindex> + %3645 = shape.assuming %3644 -> (tensor) { + %19688 = shape.broadcast %3643, %3640 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %3642, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %3637, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %3646 = stablehlo.dot %3645, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_1283 = tensor.dim %3618, %c0 : tensor + %3647 = arith.index_cast %dim_1283 : index to i64 + %from_elements_1284 = tensor.from_elements %3647, %c1_i64 : tensor<2xi64> + %3648 = stablehlo.dynamic_reshape %3618, %from_elements_1284 : (tensor, tensor<2xi64>) -> tensor + %dim_1285 = tensor.dim %3615, %c0 : tensor + %3649 = arith.index_cast %dim_1285 : index to i64 + %from_elements_1286 = tensor.from_elements %3649, %c1_i64 : tensor<2xi64> + %3650 = stablehlo.dynamic_reshape %3615, %from_elements_1286 : (tensor, tensor<2xi64>) -> tensor + %3651 = stablehlo.concatenate %3648, %3650, dim = 1 : (tensor, tensor) -> tensor + %3652 = "stablehlo.gather"(%3268, %3651) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %3653 = shape.shape_of %3646 : tensor -> tensor<2xindex> + %3654 = shape.shape_of %3652 : tensor -> tensor<2xindex> + %3655 = shape.cstr_broadcastable %3653, %3654 : tensor<2xindex>, tensor<2xindex> + %3656 = shape.assuming %3655 -> (tensor) { + %19688 = shape.broadcast %3653, %3654 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %3646, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %3652, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %3657 = shape.shape_of %3656 : tensor -> tensor<2xindex> + %3658 = stablehlo.dynamic_broadcast_in_dim %3656, %3657, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %3659 = stablehlo.dynamic_broadcast_in_dim %213, %3657, dims = [] : (tensor, tensor<2xindex>) -> tensor + %3660 = stablehlo.multiply %3658, %3659 : tensor + %dim_1287 = tensor.dim %3620, %c0 : tensor + %3661 = arith.index_cast %dim_1287 : index to i64 + %dim_1288 = tensor.dim %3656, %c0 : tensor + %3662 = arith.index_cast %dim_1288 : index to i64 + %3663 = arith.maxsi %3661, %3662 : i64 + %3664 = arith.index_cast %3663 : i64 to index + 
%from_elements_1289 = tensor.from_elements %3664, %c4096 : tensor<2xindex> + %3665 = stablehlo.dynamic_broadcast_in_dim %3620, %from_elements_1289, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_1290 = tensor.dim %3665, %c0 : tensor + %3666 = arith.index_cast %dim_1290 : index to i64 + %from_elements_1291 = tensor.from_elements %3666, %c4096_i64 : tensor<2xi64> + %3667 = stablehlo.real_dynamic_slice %3660, %c_22, %from_elements_1291, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_1292 = tensor.from_elements %3666, %c4096_i64, %c1_i64 : tensor<3xi64> + %3668 = stablehlo.dynamic_reshape %3665, %from_elements_1292 : (tensor, tensor<3xi64>) -> tensor + %3669 = stablehlo.dynamic_iota %from_elements_1292, dim = 1 : (tensor<3xi64>) -> tensor + %3670 = stablehlo.concatenate %3668, %3669, dim = 2 : (tensor, tensor) -> tensor + %3671 = "stablehlo.scatter"(%3608, %3670, %3667) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %3672 = stablehlo.slice %3228 [7:8, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %3673 = stablehlo.reshape %3672 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %3674 = stablehlo.custom_call @byteir.non_zero(%3673) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_1293 = tensor.dim %3674, %c0 : tensor + %3675 = arith.index_cast %dim_1293 : index to i64 + %from_elements_1294 = tensor.from_elements %3675, %c1_i64 : tensor<2xi64> + %3676 = stablehlo.real_dynamic_slice %3674, %c_22, %from_elements_1294, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1295 = tensor.dim %3676, %c0 : tensor + %3677 = arith.index_cast %dim_1295 : index to i64 + %from_elements_1296 = tensor.from_elements %3677 : tensor<1xi64> + %3678 = stablehlo.dynamic_reshape %3676, %from_elements_1296 : (tensor, tensor<1xi64>) -> tensor + %from_elements_1297 = tensor.from_elements %3675, %c2_i64 : tensor<2xi64> + %3679 = stablehlo.real_dynamic_slice %3674, %c_24, %from_elements_1297, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1298 = tensor.dim %3679, %c0 : tensor + %3680 = arith.index_cast %dim_1298 : index to i64 + %from_elements_1299 = tensor.from_elements %3680 : tensor<1xi64> + %3681 = stablehlo.dynamic_reshape %3679, %from_elements_1299 : (tensor, tensor<1xi64>) -> tensor + %dim_1300 = tensor.dim %3681, %c0 : tensor + %3682 = arith.index_cast %dim_1300 : index to i64 + %from_elements_1301 = tensor.from_elements %3682, %c1_i64 : tensor<2xi64> + %3683 = stablehlo.dynamic_reshape %3681, %from_elements_1301 : (tensor, tensor<2xi64>) -> tensor + %dim_1302 = tensor.dim %3683, %c0 : tensor + %3684 = arith.index_cast %dim_1302 : index to i64 + %from_elements_1303 = tensor.from_elements %c1_i64, %3684, %c4096_i64 : tensor<3xi64> + %3685 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_1303, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1304 = tensor.dim %3685, %c1 : tensor<1x?x4096xi64> + %3686 = arith.index_cast %dim_1304 : index to i64 + %from_elements_1305 = tensor.from_elements %c1_i64, %3686, %c4096_i64, %c1_i64 : tensor<4xi64> + %3687 = stablehlo.dynamic_reshape %3685, %from_elements_1305 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %3688 = 
stablehlo.dynamic_broadcast_in_dim %3683, %from_elements_1303, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1306 = tensor.dim %3688, %c1 : tensor<1x?x4096xi64> + %3689 = arith.index_cast %dim_1306 : index to i64 + %from_elements_1307 = tensor.from_elements %c1_i64, %3689, %c4096_i64, %c1_i64 : tensor<4xi64> + %3690 = stablehlo.dynamic_reshape %3688, %from_elements_1307 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %3691 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_1303, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1308 = tensor.dim %3691, %c1 : tensor<1x?x4096xi64> + %3692 = arith.index_cast %dim_1308 : index to i64 + %from_elements_1309 = tensor.from_elements %c1_i64, %3692, %c4096_i64, %c1_i64 : tensor<4xi64> + %3693 = stablehlo.dynamic_reshape %3691, %from_elements_1309 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %3694 = stablehlo.concatenate %3687, %3690, %3693, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %3695 = "stablehlo.gather"(%3239, %3694) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %3696 = shape.shape_of %3695 : tensor<1x?x4096xf32> -> tensor<3xindex> + %3697 = shape.num_elements %3696 : tensor<3xindex> -> index + %3698 = stablehlo.compute_reshape_shape %3697, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %3699 = stablehlo.dynamic_reshape %3695, %3698 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %3700 = stablehlo.dot %3699, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %3701 = stablehlo.logistic %3700 : tensor + %3702 = shape.shape_of %3701 : tensor -> tensor<2xindex> + %3703 = shape.shape_of %3700 : tensor -> tensor<2xindex> + %3704 = shape.cstr_broadcastable %3702, %3703 : tensor<2xindex>, tensor<2xindex> + %3705 = shape.assuming %3704 -> (tensor) { + %19688 = shape.broadcast %3702, %3703 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %3701, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %3700, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %3706 = shape.shape_of %3705 : tensor -> tensor<2xindex> + %3707 = shape.cstr_broadcastable %3706, %3703 : tensor<2xindex>, tensor<2xindex> + %3708 = shape.assuming %3707 -> (tensor) { + %19688 = shape.broadcast %3706, %3703 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %3705, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %3700, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %3709 = stablehlo.dot %3708, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_1310 = tensor.dim %3681, %c0 : tensor + %3710 = arith.index_cast %dim_1310 : index to i64 + %from_elements_1311 = tensor.from_elements %3710, %c1_i64 : tensor<2xi64> + %3711 = stablehlo.dynamic_reshape %3681, %from_elements_1311 : (tensor, tensor<2xi64>) -> tensor + %dim_1312 = tensor.dim %3678, %c0 : tensor + %3712 = arith.index_cast %dim_1312 : index to i64 + %from_elements_1313 = tensor.from_elements %3712, %c1_i64 : 
tensor<2xi64> + %3713 = stablehlo.dynamic_reshape %3678, %from_elements_1313 : (tensor, tensor<2xi64>) -> tensor + %3714 = stablehlo.concatenate %3711, %3713, dim = 1 : (tensor, tensor) -> tensor + %3715 = "stablehlo.gather"(%3268, %3714) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %3716 = shape.shape_of %3709 : tensor -> tensor<2xindex> + %3717 = shape.shape_of %3715 : tensor -> tensor<2xindex> + %3718 = shape.cstr_broadcastable %3716, %3717 : tensor<2xindex>, tensor<2xindex> + %3719 = shape.assuming %3718 -> (tensor) { + %19688 = shape.broadcast %3716, %3717 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %3709, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %3715, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %3720 = shape.shape_of %3719 : tensor -> tensor<2xindex> + %3721 = stablehlo.dynamic_broadcast_in_dim %3719, %3720, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %3722 = stablehlo.dynamic_broadcast_in_dim %213, %3720, dims = [] : (tensor, tensor<2xindex>) -> tensor + %3723 = stablehlo.multiply %3721, %3722 : tensor + %dim_1314 = tensor.dim %3683, %c0 : tensor + %3724 = arith.index_cast %dim_1314 : index to i64 + %dim_1315 = tensor.dim %3719, %c0 : tensor + %3725 = arith.index_cast %dim_1315 : index to i64 + %3726 = arith.maxsi %3724, %3725 : i64 + %3727 = arith.index_cast %3726 : i64 to index + %from_elements_1316 = tensor.from_elements %3727, %c4096 : tensor<2xindex> + %3728 = stablehlo.dynamic_broadcast_in_dim %3683, %from_elements_1316, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_1317 = tensor.dim %3728, %c0 : tensor + %3729 = arith.index_cast %dim_1317 : index to i64 + %from_elements_1318 = tensor.from_elements %3729, %c4096_i64 : tensor<2xi64> + %3730 = stablehlo.real_dynamic_slice %3723, %c_22, %from_elements_1318, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_1319 = tensor.from_elements %3729, %c4096_i64, %c1_i64 : tensor<3xi64> + %3731 = stablehlo.dynamic_reshape %3728, %from_elements_1319 : (tensor, tensor<3xi64>) -> tensor + %3732 = stablehlo.dynamic_iota %from_elements_1319, dim = 1 : (tensor<3xi64>) -> tensor + %3733 = stablehlo.concatenate %3731, %3732, dim = 2 : (tensor, tensor) -> tensor + %3734 = "stablehlo.scatter"(%3671, %3733, %3730) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %3735 = stablehlo.reshape %3734 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %3736 = stablehlo.add %3201, %3735 : tensor<3x1x4096xf32> + %3737 = stablehlo.broadcast_in_dim %3736, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %3738 = stablehlo.power %3737, %15 : tensor<3x1x4096xf32> + %3739 = stablehlo.reduce(%3738 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %3740 = stablehlo.reshape %3739 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %3741 = stablehlo.broadcast_in_dim %3740, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %3742 = stablehlo.divide %3741, %21 : 
tensor<3x1x1xf32> + %3743 = stablehlo.broadcast_in_dim %3742, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %3744 = stablehlo.add %3743, %25 : tensor<3x1x1xf32> + %3745 = stablehlo.rsqrt %3744 : tensor<3x1x1xf32> + %3746 = stablehlo.broadcast_in_dim %3745, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %3747 = stablehlo.multiply %3737, %3746 : tensor<3x1x4096xf32> + %3748 = stablehlo.broadcast_in_dim %3747, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %3749 = stablehlo.multiply %3748, %31 : tensor<3x1x4096xf32> + %3750 = stablehlo.reshape %3749 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %3751 = stablehlo.dot %3750, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %3752 = stablehlo.reshape %3751 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %3753 = stablehlo.dot %3750, %37 : (tensor<3x4096xf32>, tensor<4096x1024xf32>) -> tensor<3x1024xf32> + %3754 = stablehlo.reshape %3753 : (tensor<3x1024xf32>) -> tensor<3x1x1024xf32> + %3755 = stablehlo.reshape %3752 : (tensor<3x1x4096xf32>) -> tensor<3x1x32x128xf32> + %3756 = stablehlo.transpose %3755, dims = [0, 2, 1, 3] : (tensor<3x1x32x128xf32>) -> tensor<3x32x1x128xf32> + %3757 = stablehlo.reshape %3754 : (tensor<3x1x1024xf32>) -> tensor<3x1x8x128xf32> + %3758 = stablehlo.transpose %3757, dims = [0, 2, 1, 3] : (tensor<3x1x8x128xf32>) -> tensor<3x8x1x128xf32> + %3759 = stablehlo.slice %arg12 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %3760 = stablehlo.slice %arg13 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %3761 = "stablehlo.gather"(%3759, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %3762 = stablehlo.reshape %3761 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %3763 = "stablehlo.gather"(%3760, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %3764 = stablehlo.reshape %3763 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %3765 = stablehlo.broadcast_in_dim %3756, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %3766 = stablehlo.broadcast_in_dim %3762, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %3767 = stablehlo.multiply %3765, %3766 : tensor<3x32x1x128xf32> + %3768 = stablehlo.slice %3756 [0:3, 0:32, 0:1, 0:64] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %3769 = stablehlo.slice %3756 [0:3, 0:32, 0:1, 64:128] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %3770 = stablehlo.negate %3769 : tensor<3x32x1x64xf32> + %3771 = stablehlo.concatenate %3770, %3768, dim = 3 : (tensor<3x32x1x64xf32>, tensor<3x32x1x64xf32>) -> tensor<3x32x1x128xf32> + %3772 = stablehlo.broadcast_in_dim %3771, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %3773 = stablehlo.broadcast_in_dim %3764, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %3774 = stablehlo.multiply %3772, %3773 : tensor<3x32x1x128xf32> + %3775 = stablehlo.add %3767, %3774 : tensor<3x32x1x128xf32> + %3776 = stablehlo.broadcast_in_dim %3758, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %3777 = stablehlo.broadcast_in_dim %3762, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %3778 = stablehlo.multiply %3776, %3777 : tensor<3x8x1x128xf32> + %3779 = stablehlo.slice %3758 [0:3, 
0:8, 0:1, 0:64] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %3780 = stablehlo.slice %3758 [0:3, 0:8, 0:1, 64:128] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %3781 = stablehlo.negate %3780 : tensor<3x8x1x64xf32> + %3782 = stablehlo.concatenate %3781, %3779, dim = 3 : (tensor<3x8x1x64xf32>, tensor<3x8x1x64xf32>) -> tensor<3x8x1x128xf32> + %3783 = stablehlo.broadcast_in_dim %3782, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %3784 = stablehlo.broadcast_in_dim %3764, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %3785 = stablehlo.multiply %3783, %3784 : tensor<3x8x1x128xf32> + %3786 = stablehlo.add %3778, %3785 : tensor<3x8x1x128xf32> + %3787 = stablehlo.concatenate %arg77, %3786, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %3788 = stablehlo.concatenate %arg78, %3758, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %3789 = stablehlo.reshape %3787 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %3790 = stablehlo.broadcast_in_dim %3789, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %3791 = stablehlo.reshape %3790 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %3792 = stablehlo.reshape %3788 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %3793 = stablehlo.broadcast_in_dim %3792, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %3794 = stablehlo.reshape %3793 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %3795 = stablehlo.transpose %3791, dims = [0, 1, 3, 2] : (tensor<3x32x8x128xf32>) -> tensor<3x32x128x8xf32> + %3796 = stablehlo.reshape %3775 : (tensor<3x32x1x128xf32>) -> tensor<96x1x128xf32> + %3797 = stablehlo.reshape %3795 : (tensor<3x32x128x8xf32>) -> tensor<96x128x8xf32> + %3798 = stablehlo.broadcast_in_dim %3797, dims = [0, 1, 2] : (tensor<96x128x8xf32>) -> tensor<96x128x8xf32> + %3799 = stablehlo.dot_general %3796, %3798, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x128xf32>, tensor<96x128x8xf32>) -> tensor<96x1x8xf32> + %3800 = stablehlo.reshape %3799 : (tensor<96x1x8xf32>) -> tensor<3x32x1x8xf32> + %3801 = stablehlo.broadcast_in_dim %3800, dims = [0, 1, 2, 3] : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %3802 = stablehlo.divide %3801, %89 : tensor<3x32x1x8xf32> + %3803 = stablehlo.custom_call @byteir.softmax(%3802) {byteir_attrs = {axis = 3 : i64}} : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %3804 = stablehlo.reshape %3803 : (tensor<3x32x1x8xf32>) -> tensor<96x1x8xf32> + %3805 = stablehlo.reshape %3794 : (tensor<3x32x8x128xf32>) -> tensor<96x8x128xf32> + %3806 = stablehlo.broadcast_in_dim %3805, dims = [0, 1, 2] : (tensor<96x8x128xf32>) -> tensor<96x8x128xf32> + %3807 = stablehlo.dot_general %3804, %3806, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x8xf32>, tensor<96x8x128xf32>) -> tensor<96x1x128xf32> + %3808 = stablehlo.reshape %3807 : (tensor<96x1x128xf32>) -> tensor<3x32x1x128xf32> + %3809 = stablehlo.transpose %3808, dims = [0, 2, 1, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x1x32x128xf32> + %3810 = stablehlo.reshape %3809 : (tensor<3x1x32x128xf32>) -> tensor<3x1x4096xf32> + %3811 = stablehlo.reshape %3810 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %3812 = stablehlo.dot %3811, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %3813 = stablehlo.reshape %3812 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %3814 = stablehlo.add %3736, %3813 : 
tensor<3x1x4096xf32> + %3815 = stablehlo.broadcast_in_dim %3814, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %3816 = stablehlo.power %3815, %15 : tensor<3x1x4096xf32> + %3817 = stablehlo.reduce(%3816 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %3818 = stablehlo.reshape %3817 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %3819 = stablehlo.broadcast_in_dim %3818, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %3820 = stablehlo.divide %3819, %21 : tensor<3x1x1xf32> + %3821 = stablehlo.broadcast_in_dim %3820, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %3822 = stablehlo.add %3821, %25 : tensor<3x1x1xf32> + %3823 = stablehlo.rsqrt %3822 : tensor<3x1x1xf32> + %3824 = stablehlo.broadcast_in_dim %3823, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %3825 = stablehlo.multiply %3815, %3824 : tensor<3x1x4096xf32> + %3826 = stablehlo.broadcast_in_dim %3825, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %3827 = stablehlo.multiply %3826, %31 : tensor<3x1x4096xf32> + %3828 = stablehlo.reshape %3827 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %3829 = stablehlo.dot %3828, %117 : (tensor<3x4096xf32>, tensor<4096x8xf32>) -> tensor<3x8xf32> + %3830 = stablehlo.custom_call @byteir.softmax(%3829) {byteir_attrs = {axis = 1 : i64}} : (tensor<3x8xf32>) -> tensor<3x8xf32> + %3831:2 = stablehlo.custom_call @byteir.top_k(%3830) {byteir_attrs = {axis = [1], k = 2 : i64, sorted = true}} : (tensor<3x8xf32>) -> (tensor<3x2xf32>, tensor<3x2xi64>) + %3832 = stablehlo.reduce(%3831#0 init: %cst_20) applies stablehlo.add across dimensions = [1] : (tensor<3x2xf32>, tensor) -> tensor<3xf32> + %3833 = stablehlo.reshape %3832 : (tensor<3xf32>) -> tensor<3x1xf32> + %3834 = stablehlo.broadcast_in_dim %3831#0, dims = [0, 1] : (tensor<3x2xf32>) -> tensor<3x2xf32> + %3835 = stablehlo.broadcast_in_dim %3833, dims = [0, 1] : (tensor<3x1xf32>) -> tensor<3x2xf32> + %3836 = stablehlo.divide %3834, %3835 : tensor<3x2xf32> + %3837 = stablehlo.reshape %3831#1 : (tensor<3x2xi64>) -> tensor<3x2x1xi64> + %3838 = stablehlo.broadcast_in_dim %3837, dims = [0, 1, 2] : (tensor<3x2x1xi64>) -> tensor<3x2x8xi64> + %3839 = stablehlo.compare EQ, %3838, %137, SIGNED : (tensor<3x2x8xi64>, tensor<3x2x8xi64>) -> tensor<3x2x8xi1> + %3840 = stablehlo.convert %3839 : (tensor<3x2x8xi1>) -> tensor<3x2x8xi64> + %3841 = stablehlo.transpose %3840, dims = [2, 1, 0] : (tensor<3x2x8xi64>) -> tensor<8x2x3xi64> + %3842 = stablehlo.slice %3841 [0:1, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %3843 = stablehlo.reshape %3842 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %3844 = stablehlo.custom_call @byteir.non_zero(%3843) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_1320 = tensor.dim %3844, %c0 : tensor + %3845 = arith.index_cast %dim_1320 : index to i64 + %from_elements_1321 = tensor.from_elements %3845, %c1_i64 : tensor<2xi64> + %3846 = stablehlo.real_dynamic_slice %3844, %c_22, %from_elements_1321, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1322 = tensor.dim %3846, %c0 : tensor + %3847 = arith.index_cast %dim_1322 : index to i64 + %from_elements_1323 = tensor.from_elements %3847 : tensor<1xi64> + %3848 = stablehlo.dynamic_reshape %3846, %from_elements_1323 : (tensor, tensor<1xi64>) -> tensor + %from_elements_1324 = tensor.from_elements %3845, %c2_i64 : tensor<2xi64> + %3849 = stablehlo.real_dynamic_slice %3844, %c_24, %from_elements_1324, %c_21 : 
(tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1325 = tensor.dim %3849, %c0 : tensor + %3850 = arith.index_cast %dim_1325 : index to i64 + %from_elements_1326 = tensor.from_elements %3850 : tensor<1xi64> + %3851 = stablehlo.dynamic_reshape %3849, %from_elements_1326 : (tensor, tensor<1xi64>) -> tensor + %3852 = stablehlo.reshape %3828 : (tensor<3x4096xf32>) -> tensor<1x3x4096xf32> + %dim_1327 = tensor.dim %3851, %c0 : tensor + %3853 = arith.index_cast %dim_1327 : index to i64 + %from_elements_1328 = tensor.from_elements %3853, %c1_i64 : tensor<2xi64> + %3854 = stablehlo.dynamic_reshape %3851, %from_elements_1328 : (tensor, tensor<2xi64>) -> tensor + %dim_1329 = tensor.dim %3854, %c0 : tensor + %3855 = arith.index_cast %dim_1329 : index to i64 + %from_elements_1330 = tensor.from_elements %c1_i64, %3855, %c4096_i64 : tensor<3xi64> + %3856 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_1330, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1331 = tensor.dim %3856, %c1 : tensor<1x?x4096xi64> + %3857 = arith.index_cast %dim_1331 : index to i64 + %from_elements_1332 = tensor.from_elements %c1_i64, %3857, %c4096_i64, %c1_i64 : tensor<4xi64> + %3858 = stablehlo.dynamic_reshape %3856, %from_elements_1332 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %3859 = stablehlo.dynamic_broadcast_in_dim %3854, %from_elements_1330, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1333 = tensor.dim %3859, %c1 : tensor<1x?x4096xi64> + %3860 = arith.index_cast %dim_1333 : index to i64 + %from_elements_1334 = tensor.from_elements %c1_i64, %3860, %c4096_i64, %c1_i64 : tensor<4xi64> + %3861 = stablehlo.dynamic_reshape %3859, %from_elements_1334 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %3862 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_1330, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1335 = tensor.dim %3862, %c1 : tensor<1x?x4096xi64> + %3863 = arith.index_cast %dim_1335 : index to i64 + %from_elements_1336 = tensor.from_elements %c1_i64, %3863, %c4096_i64, %c1_i64 : tensor<4xi64> + %3864 = stablehlo.dynamic_reshape %3862, %from_elements_1336 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %3865 = stablehlo.concatenate %3858, %3861, %3864, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %3866 = "stablehlo.gather"(%3852, %3865) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %3867 = shape.shape_of %3866 : tensor<1x?x4096xf32> -> tensor<3xindex> + %3868 = shape.num_elements %3867 : tensor<3xindex> -> index + %3869 = stablehlo.compute_reshape_shape %3868, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %3870 = stablehlo.dynamic_reshape %3866, %3869 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %3871 = stablehlo.dot %3870, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %3872 = stablehlo.logistic %3871 : tensor + %3873 = shape.shape_of %3872 : tensor -> tensor<2xindex> + %3874 = shape.shape_of %3871 : tensor -> tensor<2xindex> + %3875 = shape.cstr_broadcastable %3873, %3874 : tensor<2xindex>, tensor<2xindex> + %3876 = shape.assuming %3875 -> (tensor) { + %19688 = shape.broadcast %3873, %3874 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %3872, %19688, dims = [0, 1] : 
(tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %3871, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %3877 = shape.shape_of %3876 : tensor -> tensor<2xindex> + %3878 = shape.cstr_broadcastable %3877, %3874 : tensor<2xindex>, tensor<2xindex> + %3879 = shape.assuming %3878 -> (tensor) { + %19688 = shape.broadcast %3877, %3874 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %3876, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %3871, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %3880 = stablehlo.dot %3879, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %3881 = stablehlo.reshape %3836 : (tensor<3x2xf32>) -> tensor<3x2x1xf32> + %dim_1337 = tensor.dim %3851, %c0 : tensor + %3882 = arith.index_cast %dim_1337 : index to i64 + %from_elements_1338 = tensor.from_elements %3882, %c1_i64 : tensor<2xi64> + %3883 = stablehlo.dynamic_reshape %3851, %from_elements_1338 : (tensor, tensor<2xi64>) -> tensor + %dim_1339 = tensor.dim %3848, %c0 : tensor + %3884 = arith.index_cast %dim_1339 : index to i64 + %from_elements_1340 = tensor.from_elements %3884, %c1_i64 : tensor<2xi64> + %3885 = stablehlo.dynamic_reshape %3848, %from_elements_1340 : (tensor, tensor<2xi64>) -> tensor + %3886 = stablehlo.concatenate %3883, %3885, dim = 1 : (tensor, tensor) -> tensor + %3887 = "stablehlo.gather"(%3881, %3886) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %3888 = shape.shape_of %3880 : tensor -> tensor<2xindex> + %3889 = shape.shape_of %3887 : tensor -> tensor<2xindex> + %3890 = shape.cstr_broadcastable %3888, %3889 : tensor<2xindex>, tensor<2xindex> + %3891 = shape.assuming %3890 -> (tensor) { + %19688 = shape.broadcast %3888, %3889 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %3880, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %3887, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %3892 = shape.shape_of %3891 : tensor -> tensor<2xindex> + %3893 = stablehlo.dynamic_broadcast_in_dim %3891, %3892, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %3894 = stablehlo.dynamic_broadcast_in_dim %213, %3892, dims = [] : (tensor, tensor<2xindex>) -> tensor + %3895 = stablehlo.multiply %3893, %3894 : tensor + %dim_1341 = tensor.dim %3854, %c0 : tensor + %3896 = arith.index_cast %dim_1341 : index to i64 + %dim_1342 = tensor.dim %3891, %c0 : tensor + %3897 = arith.index_cast %dim_1342 : index to i64 + %3898 = arith.maxsi %3896, %3897 : i64 + %3899 = arith.index_cast %3898 : i64 to index + %from_elements_1343 = tensor.from_elements %3899, %c4096 : tensor<2xindex> + %3900 = stablehlo.dynamic_broadcast_in_dim %3854, %from_elements_1343, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_1344 = tensor.dim %3900, %c0 : tensor + %3901 = arith.index_cast %dim_1344 : index to i64 + %from_elements_1345 = tensor.from_elements %3901, %c4096_i64 : tensor<2xi64> + %3902 = stablehlo.real_dynamic_slice %3895, %c_22, %from_elements_1345, %c_21 : (tensor, 
tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_1346 = tensor.from_elements %3901, %c4096_i64, %c1_i64 : tensor<3xi64> + %3903 = stablehlo.dynamic_reshape %3900, %from_elements_1346 : (tensor, tensor<3xi64>) -> tensor + %3904 = stablehlo.dynamic_iota %from_elements_1346, dim = 1 : (tensor<3xi64>) -> tensor + %3905 = stablehlo.concatenate %3903, %3904, dim = 2 : (tensor, tensor) -> tensor + %3906 = "stablehlo.scatter"(%cst_2, %3905, %3902) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %3907 = stablehlo.slice %3841 [1:2, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %3908 = stablehlo.reshape %3907 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %3909 = stablehlo.custom_call @byteir.non_zero(%3908) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_1347 = tensor.dim %3909, %c0 : tensor + %3910 = arith.index_cast %dim_1347 : index to i64 + %from_elements_1348 = tensor.from_elements %3910, %c1_i64 : tensor<2xi64> + %3911 = stablehlo.real_dynamic_slice %3909, %c_22, %from_elements_1348, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1349 = tensor.dim %3911, %c0 : tensor + %3912 = arith.index_cast %dim_1349 : index to i64 + %from_elements_1350 = tensor.from_elements %3912 : tensor<1xi64> + %3913 = stablehlo.dynamic_reshape %3911, %from_elements_1350 : (tensor, tensor<1xi64>) -> tensor + %from_elements_1351 = tensor.from_elements %3910, %c2_i64 : tensor<2xi64> + %3914 = stablehlo.real_dynamic_slice %3909, %c_24, %from_elements_1351, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1352 = tensor.dim %3914, %c0 : tensor + %3915 = arith.index_cast %dim_1352 : index to i64 + %from_elements_1353 = tensor.from_elements %3915 : tensor<1xi64> + %3916 = stablehlo.dynamic_reshape %3914, %from_elements_1353 : (tensor, tensor<1xi64>) -> tensor + %dim_1354 = tensor.dim %3916, %c0 : tensor + %3917 = arith.index_cast %dim_1354 : index to i64 + %from_elements_1355 = tensor.from_elements %3917, %c1_i64 : tensor<2xi64> + %3918 = stablehlo.dynamic_reshape %3916, %from_elements_1355 : (tensor, tensor<2xi64>) -> tensor + %dim_1356 = tensor.dim %3918, %c0 : tensor + %3919 = arith.index_cast %dim_1356 : index to i64 + %from_elements_1357 = tensor.from_elements %c1_i64, %3919, %c4096_i64 : tensor<3xi64> + %3920 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_1357, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1358 = tensor.dim %3920, %c1 : tensor<1x?x4096xi64> + %3921 = arith.index_cast %dim_1358 : index to i64 + %from_elements_1359 = tensor.from_elements %c1_i64, %3921, %c4096_i64, %c1_i64 : tensor<4xi64> + %3922 = stablehlo.dynamic_reshape %3920, %from_elements_1359 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %3923 = stablehlo.dynamic_broadcast_in_dim %3918, %from_elements_1357, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1360 = tensor.dim %3923, %c1 : tensor<1x?x4096xi64> + %3924 = arith.index_cast %dim_1360 : index to i64 + %from_elements_1361 = tensor.from_elements %c1_i64, %3924, %c4096_i64, %c1_i64 : tensor<4xi64> + %3925 = stablehlo.dynamic_reshape %3923, %from_elements_1361 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %3926 = 
stablehlo.dynamic_broadcast_in_dim %160, %from_elements_1357, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1362 = tensor.dim %3926, %c1 : tensor<1x?x4096xi64> + %3927 = arith.index_cast %dim_1362 : index to i64 + %from_elements_1363 = tensor.from_elements %c1_i64, %3927, %c4096_i64, %c1_i64 : tensor<4xi64> + %3928 = stablehlo.dynamic_reshape %3926, %from_elements_1363 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %3929 = stablehlo.concatenate %3922, %3925, %3928, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %3930 = "stablehlo.gather"(%3852, %3929) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %3931 = shape.shape_of %3930 : tensor<1x?x4096xf32> -> tensor<3xindex> + %3932 = shape.num_elements %3931 : tensor<3xindex> -> index + %3933 = stablehlo.compute_reshape_shape %3932, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %3934 = stablehlo.dynamic_reshape %3930, %3933 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %3935 = stablehlo.dot %3934, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %3936 = stablehlo.logistic %3935 : tensor + %3937 = shape.shape_of %3936 : tensor -> tensor<2xindex> + %3938 = shape.shape_of %3935 : tensor -> tensor<2xindex> + %3939 = shape.cstr_broadcastable %3937, %3938 : tensor<2xindex>, tensor<2xindex> + %3940 = shape.assuming %3939 -> (tensor) { + %19688 = shape.broadcast %3937, %3938 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %3936, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %3935, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %3941 = shape.shape_of %3940 : tensor -> tensor<2xindex> + %3942 = shape.cstr_broadcastable %3941, %3938 : tensor<2xindex>, tensor<2xindex> + %3943 = shape.assuming %3942 -> (tensor) { + %19688 = shape.broadcast %3941, %3938 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %3940, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %3935, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %3944 = stablehlo.dot %3943, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_1364 = tensor.dim %3916, %c0 : tensor + %3945 = arith.index_cast %dim_1364 : index to i64 + %from_elements_1365 = tensor.from_elements %3945, %c1_i64 : tensor<2xi64> + %3946 = stablehlo.dynamic_reshape %3916, %from_elements_1365 : (tensor, tensor<2xi64>) -> tensor + %dim_1366 = tensor.dim %3913, %c0 : tensor + %3947 = arith.index_cast %dim_1366 : index to i64 + %from_elements_1367 = tensor.from_elements %3947, %c1_i64 : tensor<2xi64> + %3948 = stablehlo.dynamic_reshape %3913, %from_elements_1367 : (tensor, tensor<2xi64>) -> tensor + %3949 = stablehlo.concatenate %3946, %3948, dim = 1 : (tensor, tensor) -> tensor + %3950 = "stablehlo.gather"(%3881, %3949) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %3951 = shape.shape_of %3944 : tensor -> tensor<2xindex> + %3952 = shape.shape_of %3950 : tensor -> 
tensor<2xindex> + %3953 = shape.cstr_broadcastable %3951, %3952 : tensor<2xindex>, tensor<2xindex> + %3954 = shape.assuming %3953 -> (tensor) { + %19688 = shape.broadcast %3951, %3952 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %3944, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %3950, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %3955 = shape.shape_of %3954 : tensor -> tensor<2xindex> + %3956 = stablehlo.dynamic_broadcast_in_dim %3954, %3955, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %3957 = stablehlo.dynamic_broadcast_in_dim %213, %3955, dims = [] : (tensor, tensor<2xindex>) -> tensor + %3958 = stablehlo.multiply %3956, %3957 : tensor + %dim_1368 = tensor.dim %3918, %c0 : tensor + %3959 = arith.index_cast %dim_1368 : index to i64 + %dim_1369 = tensor.dim %3954, %c0 : tensor + %3960 = arith.index_cast %dim_1369 : index to i64 + %3961 = arith.maxsi %3959, %3960 : i64 + %3962 = arith.index_cast %3961 : i64 to index + %from_elements_1370 = tensor.from_elements %3962, %c4096 : tensor<2xindex> + %3963 = stablehlo.dynamic_broadcast_in_dim %3918, %from_elements_1370, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_1371 = tensor.dim %3963, %c0 : tensor + %3964 = arith.index_cast %dim_1371 : index to i64 + %from_elements_1372 = tensor.from_elements %3964, %c4096_i64 : tensor<2xi64> + %3965 = stablehlo.real_dynamic_slice %3958, %c_22, %from_elements_1372, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_1373 = tensor.from_elements %3964, %c4096_i64, %c1_i64 : tensor<3xi64> + %3966 = stablehlo.dynamic_reshape %3963, %from_elements_1373 : (tensor, tensor<3xi64>) -> tensor + %3967 = stablehlo.dynamic_iota %from_elements_1373, dim = 1 : (tensor<3xi64>) -> tensor + %3968 = stablehlo.concatenate %3966, %3967, dim = 2 : (tensor, tensor) -> tensor + %3969 = "stablehlo.scatter"(%3906, %3968, %3965) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %3970 = stablehlo.slice %3841 [2:3, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %3971 = stablehlo.reshape %3970 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %3972 = stablehlo.custom_call @byteir.non_zero(%3971) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_1374 = tensor.dim %3972, %c0 : tensor + %3973 = arith.index_cast %dim_1374 : index to i64 + %from_elements_1375 = tensor.from_elements %3973, %c1_i64 : tensor<2xi64> + %3974 = stablehlo.real_dynamic_slice %3972, %c_22, %from_elements_1375, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1376 = tensor.dim %3974, %c0 : tensor + %3975 = arith.index_cast %dim_1376 : index to i64 + %from_elements_1377 = tensor.from_elements %3975 : tensor<1xi64> + %3976 = stablehlo.dynamic_reshape %3974, %from_elements_1377 : (tensor, tensor<1xi64>) -> tensor + %from_elements_1378 = tensor.from_elements %3973, %c2_i64 : tensor<2xi64> + %3977 = stablehlo.real_dynamic_slice %3972, %c_24, %from_elements_1378, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1379 = tensor.dim %3977, %c0 : tensor + %3978 = arith.index_cast 
%dim_1379 : index to i64 + %from_elements_1380 = tensor.from_elements %3978 : tensor<1xi64> + %3979 = stablehlo.dynamic_reshape %3977, %from_elements_1380 : (tensor, tensor<1xi64>) -> tensor + %dim_1381 = tensor.dim %3979, %c0 : tensor + %3980 = arith.index_cast %dim_1381 : index to i64 + %from_elements_1382 = tensor.from_elements %3980, %c1_i64 : tensor<2xi64> + %3981 = stablehlo.dynamic_reshape %3979, %from_elements_1382 : (tensor, tensor<2xi64>) -> tensor + %dim_1383 = tensor.dim %3981, %c0 : tensor + %3982 = arith.index_cast %dim_1383 : index to i64 + %from_elements_1384 = tensor.from_elements %c1_i64, %3982, %c4096_i64 : tensor<3xi64> + %3983 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_1384, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1385 = tensor.dim %3983, %c1 : tensor<1x?x4096xi64> + %3984 = arith.index_cast %dim_1385 : index to i64 + %from_elements_1386 = tensor.from_elements %c1_i64, %3984, %c4096_i64, %c1_i64 : tensor<4xi64> + %3985 = stablehlo.dynamic_reshape %3983, %from_elements_1386 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %3986 = stablehlo.dynamic_broadcast_in_dim %3981, %from_elements_1384, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1387 = tensor.dim %3986, %c1 : tensor<1x?x4096xi64> + %3987 = arith.index_cast %dim_1387 : index to i64 + %from_elements_1388 = tensor.from_elements %c1_i64, %3987, %c4096_i64, %c1_i64 : tensor<4xi64> + %3988 = stablehlo.dynamic_reshape %3986, %from_elements_1388 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %3989 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_1384, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1389 = tensor.dim %3989, %c1 : tensor<1x?x4096xi64> + %3990 = arith.index_cast %dim_1389 : index to i64 + %from_elements_1390 = tensor.from_elements %c1_i64, %3990, %c4096_i64, %c1_i64 : tensor<4xi64> + %3991 = stablehlo.dynamic_reshape %3989, %from_elements_1390 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %3992 = stablehlo.concatenate %3985, %3988, %3991, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %3993 = "stablehlo.gather"(%3852, %3992) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %3994 = shape.shape_of %3993 : tensor<1x?x4096xf32> -> tensor<3xindex> + %3995 = shape.num_elements %3994 : tensor<3xindex> -> index + %3996 = stablehlo.compute_reshape_shape %3995, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %3997 = stablehlo.dynamic_reshape %3993, %3996 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %3998 = stablehlo.dot %3997, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %3999 = stablehlo.logistic %3998 : tensor + %4000 = shape.shape_of %3999 : tensor -> tensor<2xindex> + %4001 = shape.shape_of %3998 : tensor -> tensor<2xindex> + %4002 = shape.cstr_broadcastable %4000, %4001 : tensor<2xindex>, tensor<2xindex> + %4003 = shape.assuming %4002 -> (tensor) { + %19688 = shape.broadcast %4000, %4001 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %3999, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %3998, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + 
shape.assuming_yield %19691 : tensor + } + %4004 = shape.shape_of %4003 : tensor -> tensor<2xindex> + %4005 = shape.cstr_broadcastable %4004, %4001 : tensor<2xindex>, tensor<2xindex> + %4006 = shape.assuming %4005 -> (tensor) { + %19688 = shape.broadcast %4004, %4001 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %4003, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %3998, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %4007 = stablehlo.dot %4006, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_1391 = tensor.dim %3979, %c0 : tensor + %4008 = arith.index_cast %dim_1391 : index to i64 + %from_elements_1392 = tensor.from_elements %4008, %c1_i64 : tensor<2xi64> + %4009 = stablehlo.dynamic_reshape %3979, %from_elements_1392 : (tensor, tensor<2xi64>) -> tensor + %dim_1393 = tensor.dim %3976, %c0 : tensor + %4010 = arith.index_cast %dim_1393 : index to i64 + %from_elements_1394 = tensor.from_elements %4010, %c1_i64 : tensor<2xi64> + %4011 = stablehlo.dynamic_reshape %3976, %from_elements_1394 : (tensor, tensor<2xi64>) -> tensor + %4012 = stablehlo.concatenate %4009, %4011, dim = 1 : (tensor, tensor) -> tensor + %4013 = "stablehlo.gather"(%3881, %4012) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %4014 = shape.shape_of %4007 : tensor -> tensor<2xindex> + %4015 = shape.shape_of %4013 : tensor -> tensor<2xindex> + %4016 = shape.cstr_broadcastable %4014, %4015 : tensor<2xindex>, tensor<2xindex> + %4017 = shape.assuming %4016 -> (tensor) { + %19688 = shape.broadcast %4014, %4015 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %4007, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %4013, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %4018 = shape.shape_of %4017 : tensor -> tensor<2xindex> + %4019 = stablehlo.dynamic_broadcast_in_dim %4017, %4018, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %4020 = stablehlo.dynamic_broadcast_in_dim %213, %4018, dims = [] : (tensor, tensor<2xindex>) -> tensor + %4021 = stablehlo.multiply %4019, %4020 : tensor + %dim_1395 = tensor.dim %3981, %c0 : tensor + %4022 = arith.index_cast %dim_1395 : index to i64 + %dim_1396 = tensor.dim %4017, %c0 : tensor + %4023 = arith.index_cast %dim_1396 : index to i64 + %4024 = arith.maxsi %4022, %4023 : i64 + %4025 = arith.index_cast %4024 : i64 to index + %from_elements_1397 = tensor.from_elements %4025, %c4096 : tensor<2xindex> + %4026 = stablehlo.dynamic_broadcast_in_dim %3981, %from_elements_1397, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_1398 = tensor.dim %4026, %c0 : tensor + %4027 = arith.index_cast %dim_1398 : index to i64 + %from_elements_1399 = tensor.from_elements %4027, %c4096_i64 : tensor<2xi64> + %4028 = stablehlo.real_dynamic_slice %4021, %c_22, %from_elements_1399, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_1400 = tensor.from_elements %4027, %c4096_i64, %c1_i64 : tensor<3xi64> + %4029 = stablehlo.dynamic_reshape %4026, %from_elements_1400 : (tensor, tensor<3xi64>) -> tensor + %4030 = stablehlo.dynamic_iota 
%from_elements_1400, dim = 1 : (tensor<3xi64>) -> tensor + %4031 = stablehlo.concatenate %4029, %4030, dim = 2 : (tensor, tensor) -> tensor + %4032 = "stablehlo.scatter"(%3969, %4031, %4028) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %4033 = stablehlo.slice %3841 [3:4, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %4034 = stablehlo.reshape %4033 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %4035 = stablehlo.custom_call @byteir.non_zero(%4034) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_1401 = tensor.dim %4035, %c0 : tensor + %4036 = arith.index_cast %dim_1401 : index to i64 + %from_elements_1402 = tensor.from_elements %4036, %c1_i64 : tensor<2xi64> + %4037 = stablehlo.real_dynamic_slice %4035, %c_22, %from_elements_1402, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1403 = tensor.dim %4037, %c0 : tensor + %4038 = arith.index_cast %dim_1403 : index to i64 + %from_elements_1404 = tensor.from_elements %4038 : tensor<1xi64> + %4039 = stablehlo.dynamic_reshape %4037, %from_elements_1404 : (tensor, tensor<1xi64>) -> tensor + %from_elements_1405 = tensor.from_elements %4036, %c2_i64 : tensor<2xi64> + %4040 = stablehlo.real_dynamic_slice %4035, %c_24, %from_elements_1405, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1406 = tensor.dim %4040, %c0 : tensor + %4041 = arith.index_cast %dim_1406 : index to i64 + %from_elements_1407 = tensor.from_elements %4041 : tensor<1xi64> + %4042 = stablehlo.dynamic_reshape %4040, %from_elements_1407 : (tensor, tensor<1xi64>) -> tensor + %dim_1408 = tensor.dim %4042, %c0 : tensor + %4043 = arith.index_cast %dim_1408 : index to i64 + %from_elements_1409 = tensor.from_elements %4043, %c1_i64 : tensor<2xi64> + %4044 = stablehlo.dynamic_reshape %4042, %from_elements_1409 : (tensor, tensor<2xi64>) -> tensor + %dim_1410 = tensor.dim %4044, %c0 : tensor + %4045 = arith.index_cast %dim_1410 : index to i64 + %from_elements_1411 = tensor.from_elements %c1_i64, %4045, %c4096_i64 : tensor<3xi64> + %4046 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_1411, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1412 = tensor.dim %4046, %c1 : tensor<1x?x4096xi64> + %4047 = arith.index_cast %dim_1412 : index to i64 + %from_elements_1413 = tensor.from_elements %c1_i64, %4047, %c4096_i64, %c1_i64 : tensor<4xi64> + %4048 = stablehlo.dynamic_reshape %4046, %from_elements_1413 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %4049 = stablehlo.dynamic_broadcast_in_dim %4044, %from_elements_1411, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1414 = tensor.dim %4049, %c1 : tensor<1x?x4096xi64> + %4050 = arith.index_cast %dim_1414 : index to i64 + %from_elements_1415 = tensor.from_elements %c1_i64, %4050, %c4096_i64, %c1_i64 : tensor<4xi64> + %4051 = stablehlo.dynamic_reshape %4049, %from_elements_1415 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %4052 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_1411, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1416 = tensor.dim %4052, %c1 : tensor<1x?x4096xi64> + %4053 = arith.index_cast %dim_1416 : index to i64 + %from_elements_1417 = tensor.from_elements 
%c1_i64, %4053, %c4096_i64, %c1_i64 : tensor<4xi64> + %4054 = stablehlo.dynamic_reshape %4052, %from_elements_1417 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %4055 = stablehlo.concatenate %4048, %4051, %4054, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %4056 = "stablehlo.gather"(%3852, %4055) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %4057 = shape.shape_of %4056 : tensor<1x?x4096xf32> -> tensor<3xindex> + %4058 = shape.num_elements %4057 : tensor<3xindex> -> index + %4059 = stablehlo.compute_reshape_shape %4058, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %4060 = stablehlo.dynamic_reshape %4056, %4059 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %4061 = stablehlo.dot %4060, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %4062 = stablehlo.logistic %4061 : tensor + %4063 = shape.shape_of %4062 : tensor -> tensor<2xindex> + %4064 = shape.shape_of %4061 : tensor -> tensor<2xindex> + %4065 = shape.cstr_broadcastable %4063, %4064 : tensor<2xindex>, tensor<2xindex> + %4066 = shape.assuming %4065 -> (tensor) { + %19688 = shape.broadcast %4063, %4064 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %4062, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %4061, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %4067 = shape.shape_of %4066 : tensor -> tensor<2xindex> + %4068 = shape.cstr_broadcastable %4067, %4064 : tensor<2xindex>, tensor<2xindex> + %4069 = shape.assuming %4068 -> (tensor) { + %19688 = shape.broadcast %4067, %4064 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %4066, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %4061, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %4070 = stablehlo.dot %4069, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_1418 = tensor.dim %4042, %c0 : tensor + %4071 = arith.index_cast %dim_1418 : index to i64 + %from_elements_1419 = tensor.from_elements %4071, %c1_i64 : tensor<2xi64> + %4072 = stablehlo.dynamic_reshape %4042, %from_elements_1419 : (tensor, tensor<2xi64>) -> tensor + %dim_1420 = tensor.dim %4039, %c0 : tensor + %4073 = arith.index_cast %dim_1420 : index to i64 + %from_elements_1421 = tensor.from_elements %4073, %c1_i64 : tensor<2xi64> + %4074 = stablehlo.dynamic_reshape %4039, %from_elements_1421 : (tensor, tensor<2xi64>) -> tensor + %4075 = stablehlo.concatenate %4072, %4074, dim = 1 : (tensor, tensor) -> tensor + %4076 = "stablehlo.gather"(%3881, %4075) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %4077 = shape.shape_of %4070 : tensor -> tensor<2xindex> + %4078 = shape.shape_of %4076 : tensor -> tensor<2xindex> + %4079 = shape.cstr_broadcastable %4077, %4078 : tensor<2xindex>, tensor<2xindex> + %4080 = shape.assuming %4079 -> (tensor) { + %19688 = shape.broadcast %4077, %4078 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %4070, 
%19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %4076, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %4081 = shape.shape_of %4080 : tensor -> tensor<2xindex> + %4082 = stablehlo.dynamic_broadcast_in_dim %4080, %4081, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %4083 = stablehlo.dynamic_broadcast_in_dim %213, %4081, dims = [] : (tensor, tensor<2xindex>) -> tensor + %4084 = stablehlo.multiply %4082, %4083 : tensor + %dim_1422 = tensor.dim %4044, %c0 : tensor + %4085 = arith.index_cast %dim_1422 : index to i64 + %dim_1423 = tensor.dim %4080, %c0 : tensor + %4086 = arith.index_cast %dim_1423 : index to i64 + %4087 = arith.maxsi %4085, %4086 : i64 + %4088 = arith.index_cast %4087 : i64 to index + %from_elements_1424 = tensor.from_elements %4088, %c4096 : tensor<2xindex> + %4089 = stablehlo.dynamic_broadcast_in_dim %4044, %from_elements_1424, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_1425 = tensor.dim %4089, %c0 : tensor + %4090 = arith.index_cast %dim_1425 : index to i64 + %from_elements_1426 = tensor.from_elements %4090, %c4096_i64 : tensor<2xi64> + %4091 = stablehlo.real_dynamic_slice %4084, %c_22, %from_elements_1426, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_1427 = tensor.from_elements %4090, %c4096_i64, %c1_i64 : tensor<3xi64> + %4092 = stablehlo.dynamic_reshape %4089, %from_elements_1427 : (tensor, tensor<3xi64>) -> tensor + %4093 = stablehlo.dynamic_iota %from_elements_1427, dim = 1 : (tensor<3xi64>) -> tensor + %4094 = stablehlo.concatenate %4092, %4093, dim = 2 : (tensor, tensor) -> tensor + %4095 = "stablehlo.scatter"(%4032, %4094, %4091) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %4096 = stablehlo.slice %3841 [4:5, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %4097 = stablehlo.reshape %4096 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %4098 = stablehlo.custom_call @byteir.non_zero(%4097) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_1428 = tensor.dim %4098, %c0 : tensor + %4099 = arith.index_cast %dim_1428 : index to i64 + %from_elements_1429 = tensor.from_elements %4099, %c1_i64 : tensor<2xi64> + %4100 = stablehlo.real_dynamic_slice %4098, %c_22, %from_elements_1429, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1430 = tensor.dim %4100, %c0 : tensor + %4101 = arith.index_cast %dim_1430 : index to i64 + %from_elements_1431 = tensor.from_elements %4101 : tensor<1xi64> + %4102 = stablehlo.dynamic_reshape %4100, %from_elements_1431 : (tensor, tensor<1xi64>) -> tensor + %from_elements_1432 = tensor.from_elements %4099, %c2_i64 : tensor<2xi64> + %4103 = stablehlo.real_dynamic_slice %4098, %c_24, %from_elements_1432, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1433 = tensor.dim %4103, %c0 : tensor + %4104 = arith.index_cast %dim_1433 : index to i64 + %from_elements_1434 = tensor.from_elements %4104 : tensor<1xi64> + %4105 = stablehlo.dynamic_reshape %4103, %from_elements_1434 : (tensor, tensor<1xi64>) -> tensor + %dim_1435 = tensor.dim %4105, %c0 : tensor + %4106 = arith.index_cast %dim_1435 : index to i64 + 
%from_elements_1436 = tensor.from_elements %4106, %c1_i64 : tensor<2xi64> + %4107 = stablehlo.dynamic_reshape %4105, %from_elements_1436 : (tensor, tensor<2xi64>) -> tensor + %dim_1437 = tensor.dim %4107, %c0 : tensor + %4108 = arith.index_cast %dim_1437 : index to i64 + %from_elements_1438 = tensor.from_elements %c1_i64, %4108, %c4096_i64 : tensor<3xi64> + %4109 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_1438, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1439 = tensor.dim %4109, %c1 : tensor<1x?x4096xi64> + %4110 = arith.index_cast %dim_1439 : index to i64 + %from_elements_1440 = tensor.from_elements %c1_i64, %4110, %c4096_i64, %c1_i64 : tensor<4xi64> + %4111 = stablehlo.dynamic_reshape %4109, %from_elements_1440 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %4112 = stablehlo.dynamic_broadcast_in_dim %4107, %from_elements_1438, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1441 = tensor.dim %4112, %c1 : tensor<1x?x4096xi64> + %4113 = arith.index_cast %dim_1441 : index to i64 + %from_elements_1442 = tensor.from_elements %c1_i64, %4113, %c4096_i64, %c1_i64 : tensor<4xi64> + %4114 = stablehlo.dynamic_reshape %4112, %from_elements_1442 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %4115 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_1438, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1443 = tensor.dim %4115, %c1 : tensor<1x?x4096xi64> + %4116 = arith.index_cast %dim_1443 : index to i64 + %from_elements_1444 = tensor.from_elements %c1_i64, %4116, %c4096_i64, %c1_i64 : tensor<4xi64> + %4117 = stablehlo.dynamic_reshape %4115, %from_elements_1444 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %4118 = stablehlo.concatenate %4111, %4114, %4117, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %4119 = "stablehlo.gather"(%3852, %4118) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %4120 = shape.shape_of %4119 : tensor<1x?x4096xf32> -> tensor<3xindex> + %4121 = shape.num_elements %4120 : tensor<3xindex> -> index + %4122 = stablehlo.compute_reshape_shape %4121, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %4123 = stablehlo.dynamic_reshape %4119, %4122 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %4124 = stablehlo.dot %4123, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %4125 = stablehlo.logistic %4124 : tensor + %4126 = shape.shape_of %4125 : tensor -> tensor<2xindex> + %4127 = shape.shape_of %4124 : tensor -> tensor<2xindex> + %4128 = shape.cstr_broadcastable %4126, %4127 : tensor<2xindex>, tensor<2xindex> + %4129 = shape.assuming %4128 -> (tensor) { + %19688 = shape.broadcast %4126, %4127 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %4125, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %4124, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %4130 = shape.shape_of %4129 : tensor -> tensor<2xindex> + %4131 = shape.cstr_broadcastable %4130, %4127 : tensor<2xindex>, tensor<2xindex> + %4132 = shape.assuming %4131 -> (tensor) { + %19688 = shape.broadcast %4130, %4127 : tensor<2xindex>, tensor<2xindex> -> 
tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %4129, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %4124, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %4133 = stablehlo.dot %4132, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_1445 = tensor.dim %4105, %c0 : tensor + %4134 = arith.index_cast %dim_1445 : index to i64 + %from_elements_1446 = tensor.from_elements %4134, %c1_i64 : tensor<2xi64> + %4135 = stablehlo.dynamic_reshape %4105, %from_elements_1446 : (tensor, tensor<2xi64>) -> tensor + %dim_1447 = tensor.dim %4102, %c0 : tensor + %4136 = arith.index_cast %dim_1447 : index to i64 + %from_elements_1448 = tensor.from_elements %4136, %c1_i64 : tensor<2xi64> + %4137 = stablehlo.dynamic_reshape %4102, %from_elements_1448 : (tensor, tensor<2xi64>) -> tensor + %4138 = stablehlo.concatenate %4135, %4137, dim = 1 : (tensor, tensor) -> tensor + %4139 = "stablehlo.gather"(%3881, %4138) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %4140 = shape.shape_of %4133 : tensor -> tensor<2xindex> + %4141 = shape.shape_of %4139 : tensor -> tensor<2xindex> + %4142 = shape.cstr_broadcastable %4140, %4141 : tensor<2xindex>, tensor<2xindex> + %4143 = shape.assuming %4142 -> (tensor) { + %19688 = shape.broadcast %4140, %4141 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %4133, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %4139, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %4144 = shape.shape_of %4143 : tensor -> tensor<2xindex> + %4145 = stablehlo.dynamic_broadcast_in_dim %4143, %4144, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %4146 = stablehlo.dynamic_broadcast_in_dim %213, %4144, dims = [] : (tensor, tensor<2xindex>) -> tensor + %4147 = stablehlo.multiply %4145, %4146 : tensor + %dim_1449 = tensor.dim %4107, %c0 : tensor + %4148 = arith.index_cast %dim_1449 : index to i64 + %dim_1450 = tensor.dim %4143, %c0 : tensor + %4149 = arith.index_cast %dim_1450 : index to i64 + %4150 = arith.maxsi %4148, %4149 : i64 + %4151 = arith.index_cast %4150 : i64 to index + %from_elements_1451 = tensor.from_elements %4151, %c4096 : tensor<2xindex> + %4152 = stablehlo.dynamic_broadcast_in_dim %4107, %from_elements_1451, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_1452 = tensor.dim %4152, %c0 : tensor + %4153 = arith.index_cast %dim_1452 : index to i64 + %from_elements_1453 = tensor.from_elements %4153, %c4096_i64 : tensor<2xi64> + %4154 = stablehlo.real_dynamic_slice %4147, %c_22, %from_elements_1453, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_1454 = tensor.from_elements %4153, %c4096_i64, %c1_i64 : tensor<3xi64> + %4155 = stablehlo.dynamic_reshape %4152, %from_elements_1454 : (tensor, tensor<3xi64>) -> tensor + %4156 = stablehlo.dynamic_iota %from_elements_1454, dim = 1 : (tensor<3xi64>) -> tensor + %4157 = stablehlo.concatenate %4155, %4156, dim = 2 : (tensor, tensor) -> tensor + %4158 = "stablehlo.scatter"(%4095, %4157, %4154) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: 
tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %4159 = stablehlo.slice %3841 [5:6, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %4160 = stablehlo.reshape %4159 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %4161 = stablehlo.custom_call @byteir.non_zero(%4160) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_1455 = tensor.dim %4161, %c0 : tensor + %4162 = arith.index_cast %dim_1455 : index to i64 + %from_elements_1456 = tensor.from_elements %4162, %c1_i64 : tensor<2xi64> + %4163 = stablehlo.real_dynamic_slice %4161, %c_22, %from_elements_1456, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1457 = tensor.dim %4163, %c0 : tensor + %4164 = arith.index_cast %dim_1457 : index to i64 + %from_elements_1458 = tensor.from_elements %4164 : tensor<1xi64> + %4165 = stablehlo.dynamic_reshape %4163, %from_elements_1458 : (tensor, tensor<1xi64>) -> tensor + %from_elements_1459 = tensor.from_elements %4162, %c2_i64 : tensor<2xi64> + %4166 = stablehlo.real_dynamic_slice %4161, %c_24, %from_elements_1459, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1460 = tensor.dim %4166, %c0 : tensor + %4167 = arith.index_cast %dim_1460 : index to i64 + %from_elements_1461 = tensor.from_elements %4167 : tensor<1xi64> + %4168 = stablehlo.dynamic_reshape %4166, %from_elements_1461 : (tensor, tensor<1xi64>) -> tensor + %dim_1462 = tensor.dim %4168, %c0 : tensor + %4169 = arith.index_cast %dim_1462 : index to i64 + %from_elements_1463 = tensor.from_elements %4169, %c1_i64 : tensor<2xi64> + %4170 = stablehlo.dynamic_reshape %4168, %from_elements_1463 : (tensor, tensor<2xi64>) -> tensor + %dim_1464 = tensor.dim %4170, %c0 : tensor + %4171 = arith.index_cast %dim_1464 : index to i64 + %from_elements_1465 = tensor.from_elements %c1_i64, %4171, %c4096_i64 : tensor<3xi64> + %4172 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_1465, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1466 = tensor.dim %4172, %c1 : tensor<1x?x4096xi64> + %4173 = arith.index_cast %dim_1466 : index to i64 + %from_elements_1467 = tensor.from_elements %c1_i64, %4173, %c4096_i64, %c1_i64 : tensor<4xi64> + %4174 = stablehlo.dynamic_reshape %4172, %from_elements_1467 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %4175 = stablehlo.dynamic_broadcast_in_dim %4170, %from_elements_1465, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1468 = tensor.dim %4175, %c1 : tensor<1x?x4096xi64> + %4176 = arith.index_cast %dim_1468 : index to i64 + %from_elements_1469 = tensor.from_elements %c1_i64, %4176, %c4096_i64, %c1_i64 : tensor<4xi64> + %4177 = stablehlo.dynamic_reshape %4175, %from_elements_1469 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %4178 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_1465, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1470 = tensor.dim %4178, %c1 : tensor<1x?x4096xi64> + %4179 = arith.index_cast %dim_1470 : index to i64 + %from_elements_1471 = tensor.from_elements %c1_i64, %4179, %c4096_i64, %c1_i64 : tensor<4xi64> + %4180 = stablehlo.dynamic_reshape %4178, %from_elements_1471 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %4181 = stablehlo.concatenate %4174, %4177, %4180, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> 
tensor<1x?x4096x3xi64> + %4182 = "stablehlo.gather"(%3852, %4181) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %4183 = shape.shape_of %4182 : tensor<1x?x4096xf32> -> tensor<3xindex> + %4184 = shape.num_elements %4183 : tensor<3xindex> -> index + %4185 = stablehlo.compute_reshape_shape %4184, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %4186 = stablehlo.dynamic_reshape %4182, %4185 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %4187 = stablehlo.dot %4186, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %4188 = stablehlo.logistic %4187 : tensor + %4189 = shape.shape_of %4188 : tensor -> tensor<2xindex> + %4190 = shape.shape_of %4187 : tensor -> tensor<2xindex> + %4191 = shape.cstr_broadcastable %4189, %4190 : tensor<2xindex>, tensor<2xindex> + %4192 = shape.assuming %4191 -> (tensor) { + %19688 = shape.broadcast %4189, %4190 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %4188, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %4187, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %4193 = shape.shape_of %4192 : tensor -> tensor<2xindex> + %4194 = shape.cstr_broadcastable %4193, %4190 : tensor<2xindex>, tensor<2xindex> + %4195 = shape.assuming %4194 -> (tensor) { + %19688 = shape.broadcast %4193, %4190 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %4192, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %4187, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %4196 = stablehlo.dot %4195, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_1472 = tensor.dim %4168, %c0 : tensor + %4197 = arith.index_cast %dim_1472 : index to i64 + %from_elements_1473 = tensor.from_elements %4197, %c1_i64 : tensor<2xi64> + %4198 = stablehlo.dynamic_reshape %4168, %from_elements_1473 : (tensor, tensor<2xi64>) -> tensor + %dim_1474 = tensor.dim %4165, %c0 : tensor + %4199 = arith.index_cast %dim_1474 : index to i64 + %from_elements_1475 = tensor.from_elements %4199, %c1_i64 : tensor<2xi64> + %4200 = stablehlo.dynamic_reshape %4165, %from_elements_1475 : (tensor, tensor<2xi64>) -> tensor + %4201 = stablehlo.concatenate %4198, %4200, dim = 1 : (tensor, tensor) -> tensor + %4202 = "stablehlo.gather"(%3881, %4201) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %4203 = shape.shape_of %4196 : tensor -> tensor<2xindex> + %4204 = shape.shape_of %4202 : tensor -> tensor<2xindex> + %4205 = shape.cstr_broadcastable %4203, %4204 : tensor<2xindex>, tensor<2xindex> + %4206 = shape.assuming %4205 -> (tensor) { + %19688 = shape.broadcast %4203, %4204 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %4196, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %4202, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %4207 = shape.shape_of %4206 : tensor -> 
tensor<2xindex> + %4208 = stablehlo.dynamic_broadcast_in_dim %4206, %4207, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %4209 = stablehlo.dynamic_broadcast_in_dim %213, %4207, dims = [] : (tensor, tensor<2xindex>) -> tensor + %4210 = stablehlo.multiply %4208, %4209 : tensor + %dim_1476 = tensor.dim %4170, %c0 : tensor + %4211 = arith.index_cast %dim_1476 : index to i64 + %dim_1477 = tensor.dim %4206, %c0 : tensor + %4212 = arith.index_cast %dim_1477 : index to i64 + %4213 = arith.maxsi %4211, %4212 : i64 + %4214 = arith.index_cast %4213 : i64 to index + %from_elements_1478 = tensor.from_elements %4214, %c4096 : tensor<2xindex> + %4215 = stablehlo.dynamic_broadcast_in_dim %4170, %from_elements_1478, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_1479 = tensor.dim %4215, %c0 : tensor + %4216 = arith.index_cast %dim_1479 : index to i64 + %from_elements_1480 = tensor.from_elements %4216, %c4096_i64 : tensor<2xi64> + %4217 = stablehlo.real_dynamic_slice %4210, %c_22, %from_elements_1480, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_1481 = tensor.from_elements %4216, %c4096_i64, %c1_i64 : tensor<3xi64> + %4218 = stablehlo.dynamic_reshape %4215, %from_elements_1481 : (tensor, tensor<3xi64>) -> tensor + %4219 = stablehlo.dynamic_iota %from_elements_1481, dim = 1 : (tensor<3xi64>) -> tensor + %4220 = stablehlo.concatenate %4218, %4219, dim = 2 : (tensor, tensor) -> tensor + %4221 = "stablehlo.scatter"(%4158, %4220, %4217) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %4222 = stablehlo.slice %3841 [6:7, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %4223 = stablehlo.reshape %4222 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %4224 = stablehlo.custom_call @byteir.non_zero(%4223) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_1482 = tensor.dim %4224, %c0 : tensor + %4225 = arith.index_cast %dim_1482 : index to i64 + %from_elements_1483 = tensor.from_elements %4225, %c1_i64 : tensor<2xi64> + %4226 = stablehlo.real_dynamic_slice %4224, %c_22, %from_elements_1483, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1484 = tensor.dim %4226, %c0 : tensor + %4227 = arith.index_cast %dim_1484 : index to i64 + %from_elements_1485 = tensor.from_elements %4227 : tensor<1xi64> + %4228 = stablehlo.dynamic_reshape %4226, %from_elements_1485 : (tensor, tensor<1xi64>) -> tensor + %from_elements_1486 = tensor.from_elements %4225, %c2_i64 : tensor<2xi64> + %4229 = stablehlo.real_dynamic_slice %4224, %c_24, %from_elements_1486, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1487 = tensor.dim %4229, %c0 : tensor + %4230 = arith.index_cast %dim_1487 : index to i64 + %from_elements_1488 = tensor.from_elements %4230 : tensor<1xi64> + %4231 = stablehlo.dynamic_reshape %4229, %from_elements_1488 : (tensor, tensor<1xi64>) -> tensor + %dim_1489 = tensor.dim %4231, %c0 : tensor + %4232 = arith.index_cast %dim_1489 : index to i64 + %from_elements_1490 = tensor.from_elements %4232, %c1_i64 : tensor<2xi64> + %4233 = stablehlo.dynamic_reshape %4231, %from_elements_1490 : (tensor, tensor<2xi64>) -> tensor + %dim_1491 = tensor.dim %4233, %c0 : tensor + %4234 = arith.index_cast %dim_1491 : index to i64 + %from_elements_1492 = tensor.from_elements 
%c1_i64, %4234, %c4096_i64 : tensor<3xi64> + %4235 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_1492, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1493 = tensor.dim %4235, %c1 : tensor<1x?x4096xi64> + %4236 = arith.index_cast %dim_1493 : index to i64 + %from_elements_1494 = tensor.from_elements %c1_i64, %4236, %c4096_i64, %c1_i64 : tensor<4xi64> + %4237 = stablehlo.dynamic_reshape %4235, %from_elements_1494 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %4238 = stablehlo.dynamic_broadcast_in_dim %4233, %from_elements_1492, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1495 = tensor.dim %4238, %c1 : tensor<1x?x4096xi64> + %4239 = arith.index_cast %dim_1495 : index to i64 + %from_elements_1496 = tensor.from_elements %c1_i64, %4239, %c4096_i64, %c1_i64 : tensor<4xi64> + %4240 = stablehlo.dynamic_reshape %4238, %from_elements_1496 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %4241 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_1492, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1497 = tensor.dim %4241, %c1 : tensor<1x?x4096xi64> + %4242 = arith.index_cast %dim_1497 : index to i64 + %from_elements_1498 = tensor.from_elements %c1_i64, %4242, %c4096_i64, %c1_i64 : tensor<4xi64> + %4243 = stablehlo.dynamic_reshape %4241, %from_elements_1498 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %4244 = stablehlo.concatenate %4237, %4240, %4243, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %4245 = "stablehlo.gather"(%3852, %4244) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %4246 = shape.shape_of %4245 : tensor<1x?x4096xf32> -> tensor<3xindex> + %4247 = shape.num_elements %4246 : tensor<3xindex> -> index + %4248 = stablehlo.compute_reshape_shape %4247, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %4249 = stablehlo.dynamic_reshape %4245, %4248 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %4250 = stablehlo.dot %4249, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %4251 = stablehlo.logistic %4250 : tensor + %4252 = shape.shape_of %4251 : tensor -> tensor<2xindex> + %4253 = shape.shape_of %4250 : tensor -> tensor<2xindex> + %4254 = shape.cstr_broadcastable %4252, %4253 : tensor<2xindex>, tensor<2xindex> + %4255 = shape.assuming %4254 -> (tensor) { + %19688 = shape.broadcast %4252, %4253 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %4251, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %4250, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %4256 = shape.shape_of %4255 : tensor -> tensor<2xindex> + %4257 = shape.cstr_broadcastable %4256, %4253 : tensor<2xindex>, tensor<2xindex> + %4258 = shape.assuming %4257 -> (tensor) { + %19688 = shape.broadcast %4256, %4253 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %4255, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %4250, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + 
shape.assuming_yield %19691 : tensor + } + %4259 = stablehlo.dot %4258, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_1499 = tensor.dim %4231, %c0 : tensor + %4260 = arith.index_cast %dim_1499 : index to i64 + %from_elements_1500 = tensor.from_elements %4260, %c1_i64 : tensor<2xi64> + %4261 = stablehlo.dynamic_reshape %4231, %from_elements_1500 : (tensor, tensor<2xi64>) -> tensor + %dim_1501 = tensor.dim %4228, %c0 : tensor + %4262 = arith.index_cast %dim_1501 : index to i64 + %from_elements_1502 = tensor.from_elements %4262, %c1_i64 : tensor<2xi64> + %4263 = stablehlo.dynamic_reshape %4228, %from_elements_1502 : (tensor, tensor<2xi64>) -> tensor + %4264 = stablehlo.concatenate %4261, %4263, dim = 1 : (tensor, tensor) -> tensor + %4265 = "stablehlo.gather"(%3881, %4264) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %4266 = shape.shape_of %4259 : tensor -> tensor<2xindex> + %4267 = shape.shape_of %4265 : tensor -> tensor<2xindex> + %4268 = shape.cstr_broadcastable %4266, %4267 : tensor<2xindex>, tensor<2xindex> + %4269 = shape.assuming %4268 -> (tensor) { + %19688 = shape.broadcast %4266, %4267 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %4259, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %4265, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %4270 = shape.shape_of %4269 : tensor -> tensor<2xindex> + %4271 = stablehlo.dynamic_broadcast_in_dim %4269, %4270, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %4272 = stablehlo.dynamic_broadcast_in_dim %213, %4270, dims = [] : (tensor, tensor<2xindex>) -> tensor + %4273 = stablehlo.multiply %4271, %4272 : tensor + %dim_1503 = tensor.dim %4233, %c0 : tensor + %4274 = arith.index_cast %dim_1503 : index to i64 + %dim_1504 = tensor.dim %4269, %c0 : tensor + %4275 = arith.index_cast %dim_1504 : index to i64 + %4276 = arith.maxsi %4274, %4275 : i64 + %4277 = arith.index_cast %4276 : i64 to index + %from_elements_1505 = tensor.from_elements %4277, %c4096 : tensor<2xindex> + %4278 = stablehlo.dynamic_broadcast_in_dim %4233, %from_elements_1505, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_1506 = tensor.dim %4278, %c0 : tensor + %4279 = arith.index_cast %dim_1506 : index to i64 + %from_elements_1507 = tensor.from_elements %4279, %c4096_i64 : tensor<2xi64> + %4280 = stablehlo.real_dynamic_slice %4273, %c_22, %from_elements_1507, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_1508 = tensor.from_elements %4279, %c4096_i64, %c1_i64 : tensor<3xi64> + %4281 = stablehlo.dynamic_reshape %4278, %from_elements_1508 : (tensor, tensor<3xi64>) -> tensor + %4282 = stablehlo.dynamic_iota %from_elements_1508, dim = 1 : (tensor<3xi64>) -> tensor + %4283 = stablehlo.concatenate %4281, %4282, dim = 2 : (tensor, tensor) -> tensor + %4284 = "stablehlo.scatter"(%4221, %4283, %4280) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %4285 = stablehlo.slice %3841 [7:8, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %4286 = stablehlo.reshape 
%4285 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %4287 = stablehlo.custom_call @byteir.non_zero(%4286) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_1509 = tensor.dim %4287, %c0 : tensor + %4288 = arith.index_cast %dim_1509 : index to i64 + %from_elements_1510 = tensor.from_elements %4288, %c1_i64 : tensor<2xi64> + %4289 = stablehlo.real_dynamic_slice %4287, %c_22, %from_elements_1510, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1511 = tensor.dim %4289, %c0 : tensor + %4290 = arith.index_cast %dim_1511 : index to i64 + %from_elements_1512 = tensor.from_elements %4290 : tensor<1xi64> + %4291 = stablehlo.dynamic_reshape %4289, %from_elements_1512 : (tensor, tensor<1xi64>) -> tensor + %from_elements_1513 = tensor.from_elements %4288, %c2_i64 : tensor<2xi64> + %4292 = stablehlo.real_dynamic_slice %4287, %c_24, %from_elements_1513, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1514 = tensor.dim %4292, %c0 : tensor + %4293 = arith.index_cast %dim_1514 : index to i64 + %from_elements_1515 = tensor.from_elements %4293 : tensor<1xi64> + %4294 = stablehlo.dynamic_reshape %4292, %from_elements_1515 : (tensor, tensor<1xi64>) -> tensor + %dim_1516 = tensor.dim %4294, %c0 : tensor + %4295 = arith.index_cast %dim_1516 : index to i64 + %from_elements_1517 = tensor.from_elements %4295, %c1_i64 : tensor<2xi64> + %4296 = stablehlo.dynamic_reshape %4294, %from_elements_1517 : (tensor, tensor<2xi64>) -> tensor + %dim_1518 = tensor.dim %4296, %c0 : tensor + %4297 = arith.index_cast %dim_1518 : index to i64 + %from_elements_1519 = tensor.from_elements %c1_i64, %4297, %c4096_i64 : tensor<3xi64> + %4298 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_1519, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1520 = tensor.dim %4298, %c1 : tensor<1x?x4096xi64> + %4299 = arith.index_cast %dim_1520 : index to i64 + %from_elements_1521 = tensor.from_elements %c1_i64, %4299, %c4096_i64, %c1_i64 : tensor<4xi64> + %4300 = stablehlo.dynamic_reshape %4298, %from_elements_1521 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %4301 = stablehlo.dynamic_broadcast_in_dim %4296, %from_elements_1519, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1522 = tensor.dim %4301, %c1 : tensor<1x?x4096xi64> + %4302 = arith.index_cast %dim_1522 : index to i64 + %from_elements_1523 = tensor.from_elements %c1_i64, %4302, %c4096_i64, %c1_i64 : tensor<4xi64> + %4303 = stablehlo.dynamic_reshape %4301, %from_elements_1523 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %4304 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_1519, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1524 = tensor.dim %4304, %c1 : tensor<1x?x4096xi64> + %4305 = arith.index_cast %dim_1524 : index to i64 + %from_elements_1525 = tensor.from_elements %c1_i64, %4305, %c4096_i64, %c1_i64 : tensor<4xi64> + %4306 = stablehlo.dynamic_reshape %4304, %from_elements_1525 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %4307 = stablehlo.concatenate %4300, %4303, %4306, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %4308 = "stablehlo.gather"(%3852, %4307) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %4309 = shape.shape_of %4308 : tensor<1x?x4096xf32> -> 
tensor<3xindex> + %4310 = shape.num_elements %4309 : tensor<3xindex> -> index + %4311 = stablehlo.compute_reshape_shape %4310, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %4312 = stablehlo.dynamic_reshape %4308, %4311 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %4313 = stablehlo.dot %4312, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %4314 = stablehlo.logistic %4313 : tensor + %4315 = shape.shape_of %4314 : tensor -> tensor<2xindex> + %4316 = shape.shape_of %4313 : tensor -> tensor<2xindex> + %4317 = shape.cstr_broadcastable %4315, %4316 : tensor<2xindex>, tensor<2xindex> + %4318 = shape.assuming %4317 -> (tensor) { + %19688 = shape.broadcast %4315, %4316 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %4314, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %4313, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %4319 = shape.shape_of %4318 : tensor -> tensor<2xindex> + %4320 = shape.cstr_broadcastable %4319, %4316 : tensor<2xindex>, tensor<2xindex> + %4321 = shape.assuming %4320 -> (tensor) { + %19688 = shape.broadcast %4319, %4316 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %4318, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %4313, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %4322 = stablehlo.dot %4321, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_1526 = tensor.dim %4294, %c0 : tensor + %4323 = arith.index_cast %dim_1526 : index to i64 + %from_elements_1527 = tensor.from_elements %4323, %c1_i64 : tensor<2xi64> + %4324 = stablehlo.dynamic_reshape %4294, %from_elements_1527 : (tensor, tensor<2xi64>) -> tensor + %dim_1528 = tensor.dim %4291, %c0 : tensor + %4325 = arith.index_cast %dim_1528 : index to i64 + %from_elements_1529 = tensor.from_elements %4325, %c1_i64 : tensor<2xi64> + %4326 = stablehlo.dynamic_reshape %4291, %from_elements_1529 : (tensor, tensor<2xi64>) -> tensor + %4327 = stablehlo.concatenate %4324, %4326, dim = 1 : (tensor, tensor) -> tensor + %4328 = "stablehlo.gather"(%3881, %4327) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %4329 = shape.shape_of %4322 : tensor -> tensor<2xindex> + %4330 = shape.shape_of %4328 : tensor -> tensor<2xindex> + %4331 = shape.cstr_broadcastable %4329, %4330 : tensor<2xindex>, tensor<2xindex> + %4332 = shape.assuming %4331 -> (tensor) { + %19688 = shape.broadcast %4329, %4330 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %4322, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %4328, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %4333 = shape.shape_of %4332 : tensor -> tensor<2xindex> + %4334 = stablehlo.dynamic_broadcast_in_dim %4332, %4333, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %4335 = stablehlo.dynamic_broadcast_in_dim %213, %4333, dims = [] : (tensor, tensor<2xindex>) -> tensor + %4336 = stablehlo.multiply %4334, %4335 : tensor + %dim_1530 = 
tensor.dim %4296, %c0 : tensor + %4337 = arith.index_cast %dim_1530 : index to i64 + %dim_1531 = tensor.dim %4332, %c0 : tensor + %4338 = arith.index_cast %dim_1531 : index to i64 + %4339 = arith.maxsi %4337, %4338 : i64 + %4340 = arith.index_cast %4339 : i64 to index + %from_elements_1532 = tensor.from_elements %4340, %c4096 : tensor<2xindex> + %4341 = stablehlo.dynamic_broadcast_in_dim %4296, %from_elements_1532, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_1533 = tensor.dim %4341, %c0 : tensor + %4342 = arith.index_cast %dim_1533 : index to i64 + %from_elements_1534 = tensor.from_elements %4342, %c4096_i64 : tensor<2xi64> + %4343 = stablehlo.real_dynamic_slice %4336, %c_22, %from_elements_1534, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_1535 = tensor.from_elements %4342, %c4096_i64, %c1_i64 : tensor<3xi64> + %4344 = stablehlo.dynamic_reshape %4341, %from_elements_1535 : (tensor, tensor<3xi64>) -> tensor + %4345 = stablehlo.dynamic_iota %from_elements_1535, dim = 1 : (tensor<3xi64>) -> tensor + %4346 = stablehlo.concatenate %4344, %4345, dim = 2 : (tensor, tensor) -> tensor + %4347 = "stablehlo.scatter"(%4284, %4346, %4343) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %4348 = stablehlo.reshape %4347 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %4349 = stablehlo.add %3814, %4348 : tensor<3x1x4096xf32> + %4350 = stablehlo.broadcast_in_dim %4349, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %4351 = stablehlo.power %4350, %15 : tensor<3x1x4096xf32> + %4352 = stablehlo.reduce(%4351 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %4353 = stablehlo.reshape %4352 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %4354 = stablehlo.broadcast_in_dim %4353, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %4355 = stablehlo.divide %4354, %21 : tensor<3x1x1xf32> + %4356 = stablehlo.broadcast_in_dim %4355, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %4357 = stablehlo.add %4356, %25 : tensor<3x1x1xf32> + %4358 = stablehlo.rsqrt %4357 : tensor<3x1x1xf32> + %4359 = stablehlo.broadcast_in_dim %4358, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %4360 = stablehlo.multiply %4350, %4359 : tensor<3x1x4096xf32> + %4361 = stablehlo.broadcast_in_dim %4360, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %4362 = stablehlo.multiply %4361, %31 : tensor<3x1x4096xf32> + %4363 = stablehlo.reshape %4362 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %4364 = stablehlo.dot %4363, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %4365 = stablehlo.reshape %4364 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %4366 = stablehlo.dot %4363, %37 : (tensor<3x4096xf32>, tensor<4096x1024xf32>) -> tensor<3x1024xf32> + %4367 = stablehlo.reshape %4366 : (tensor<3x1024xf32>) -> tensor<3x1x1024xf32> + %4368 = stablehlo.reshape %4365 : (tensor<3x1x4096xf32>) -> tensor<3x1x32x128xf32> + %4369 = stablehlo.transpose %4368, dims = [0, 2, 1, 3] : (tensor<3x1x32x128xf32>) -> tensor<3x32x1x128xf32> + %4370 = stablehlo.reshape %4367 : (tensor<3x1x1024xf32>) -> tensor<3x1x8x128xf32> + %4371 = stablehlo.transpose %4370, dims = [0, 2, 1, 3] : 
(tensor<3x1x8x128xf32>) -> tensor<3x8x1x128xf32> + %4372 = stablehlo.slice %arg14 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %4373 = stablehlo.slice %arg15 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %4374 = "stablehlo.gather"(%4372, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %4375 = stablehlo.reshape %4374 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %4376 = "stablehlo.gather"(%4373, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %4377 = stablehlo.reshape %4376 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %4378 = stablehlo.broadcast_in_dim %4369, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %4379 = stablehlo.broadcast_in_dim %4375, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %4380 = stablehlo.multiply %4378, %4379 : tensor<3x32x1x128xf32> + %4381 = stablehlo.slice %4369 [0:3, 0:32, 0:1, 0:64] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %4382 = stablehlo.slice %4369 [0:3, 0:32, 0:1, 64:128] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %4383 = stablehlo.negate %4382 : tensor<3x32x1x64xf32> + %4384 = stablehlo.concatenate %4383, %4381, dim = 3 : (tensor<3x32x1x64xf32>, tensor<3x32x1x64xf32>) -> tensor<3x32x1x128xf32> + %4385 = stablehlo.broadcast_in_dim %4384, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %4386 = stablehlo.broadcast_in_dim %4377, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %4387 = stablehlo.multiply %4385, %4386 : tensor<3x32x1x128xf32> + %4388 = stablehlo.add %4380, %4387 : tensor<3x32x1x128xf32> + %4389 = stablehlo.broadcast_in_dim %4371, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %4390 = stablehlo.broadcast_in_dim %4375, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %4391 = stablehlo.multiply %4389, %4390 : tensor<3x8x1x128xf32> + %4392 = stablehlo.slice %4371 [0:3, 0:8, 0:1, 0:64] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %4393 = stablehlo.slice %4371 [0:3, 0:8, 0:1, 64:128] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %4394 = stablehlo.negate %4393 : tensor<3x8x1x64xf32> + %4395 = stablehlo.concatenate %4394, %4392, dim = 3 : (tensor<3x8x1x64xf32>, tensor<3x8x1x64xf32>) -> tensor<3x8x1x128xf32> + %4396 = stablehlo.broadcast_in_dim %4395, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %4397 = stablehlo.broadcast_in_dim %4377, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %4398 = stablehlo.multiply %4396, %4397 : tensor<3x8x1x128xf32> + %4399 = stablehlo.add %4391, %4398 : tensor<3x8x1x128xf32> + %4400 = stablehlo.concatenate %arg79, %4399, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %4401 = stablehlo.concatenate %arg80, %4371, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %4402 = stablehlo.reshape %4400 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %4403 = stablehlo.broadcast_in_dim %4402, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %4404 = stablehlo.reshape %4403 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %4405 = stablehlo.reshape %4401 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %4406 
= stablehlo.broadcast_in_dim %4405, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %4407 = stablehlo.reshape %4406 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %4408 = stablehlo.transpose %4404, dims = [0, 1, 3, 2] : (tensor<3x32x8x128xf32>) -> tensor<3x32x128x8xf32> + %4409 = stablehlo.reshape %4388 : (tensor<3x32x1x128xf32>) -> tensor<96x1x128xf32> + %4410 = stablehlo.reshape %4408 : (tensor<3x32x128x8xf32>) -> tensor<96x128x8xf32> + %4411 = stablehlo.broadcast_in_dim %4410, dims = [0, 1, 2] : (tensor<96x128x8xf32>) -> tensor<96x128x8xf32> + %4412 = stablehlo.dot_general %4409, %4411, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x128xf32>, tensor<96x128x8xf32>) -> tensor<96x1x8xf32> + %4413 = stablehlo.reshape %4412 : (tensor<96x1x8xf32>) -> tensor<3x32x1x8xf32> + %4414 = stablehlo.broadcast_in_dim %4413, dims = [0, 1, 2, 3] : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %4415 = stablehlo.divide %4414, %89 : tensor<3x32x1x8xf32> + %4416 = stablehlo.custom_call @byteir.softmax(%4415) {byteir_attrs = {axis = 3 : i64}} : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %4417 = stablehlo.reshape %4416 : (tensor<3x32x1x8xf32>) -> tensor<96x1x8xf32> + %4418 = stablehlo.reshape %4407 : (tensor<3x32x8x128xf32>) -> tensor<96x8x128xf32> + %4419 = stablehlo.broadcast_in_dim %4418, dims = [0, 1, 2] : (tensor<96x8x128xf32>) -> tensor<96x8x128xf32> + %4420 = stablehlo.dot_general %4417, %4419, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x8xf32>, tensor<96x8x128xf32>) -> tensor<96x1x128xf32> + %4421 = stablehlo.reshape %4420 : (tensor<96x1x128xf32>) -> tensor<3x32x1x128xf32> + %4422 = stablehlo.transpose %4421, dims = [0, 2, 1, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x1x32x128xf32> + %4423 = stablehlo.reshape %4422 : (tensor<3x1x32x128xf32>) -> tensor<3x1x4096xf32> + %4424 = stablehlo.reshape %4423 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %4425 = stablehlo.dot %4424, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %4426 = stablehlo.reshape %4425 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %4427 = stablehlo.add %4349, %4426 : tensor<3x1x4096xf32> + %4428 = stablehlo.broadcast_in_dim %4427, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %4429 = stablehlo.power %4428, %15 : tensor<3x1x4096xf32> + %4430 = stablehlo.reduce(%4429 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %4431 = stablehlo.reshape %4430 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %4432 = stablehlo.broadcast_in_dim %4431, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %4433 = stablehlo.divide %4432, %21 : tensor<3x1x1xf32> + %4434 = stablehlo.broadcast_in_dim %4433, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %4435 = stablehlo.add %4434, %25 : tensor<3x1x1xf32> + %4436 = stablehlo.rsqrt %4435 : tensor<3x1x1xf32> + %4437 = stablehlo.broadcast_in_dim %4436, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %4438 = stablehlo.multiply %4428, %4437 : tensor<3x1x4096xf32> + %4439 = stablehlo.broadcast_in_dim %4438, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %4440 = stablehlo.multiply %4439, %31 : tensor<3x1x4096xf32> + %4441 = stablehlo.reshape %4440 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %4442 = stablehlo.dot %4441, %117 : (tensor<3x4096xf32>, tensor<4096x8xf32>) -> tensor<3x8xf32> + %4443 = stablehlo.custom_call @byteir.softmax(%4442) 
{byteir_attrs = {axis = 1 : i64}} : (tensor<3x8xf32>) -> tensor<3x8xf32> + %4444:2 = stablehlo.custom_call @byteir.top_k(%4443) {byteir_attrs = {axis = [1], k = 2 : i64, sorted = true}} : (tensor<3x8xf32>) -> (tensor<3x2xf32>, tensor<3x2xi64>) + %4445 = stablehlo.reduce(%4444#0 init: %cst_20) applies stablehlo.add across dimensions = [1] : (tensor<3x2xf32>, tensor) -> tensor<3xf32> + %4446 = stablehlo.reshape %4445 : (tensor<3xf32>) -> tensor<3x1xf32> + %4447 = stablehlo.broadcast_in_dim %4444#0, dims = [0, 1] : (tensor<3x2xf32>) -> tensor<3x2xf32> + %4448 = stablehlo.broadcast_in_dim %4446, dims = [0, 1] : (tensor<3x1xf32>) -> tensor<3x2xf32> + %4449 = stablehlo.divide %4447, %4448 : tensor<3x2xf32> + %4450 = stablehlo.reshape %4444#1 : (tensor<3x2xi64>) -> tensor<3x2x1xi64> + %4451 = stablehlo.broadcast_in_dim %4450, dims = [0, 1, 2] : (tensor<3x2x1xi64>) -> tensor<3x2x8xi64> + %4452 = stablehlo.compare EQ, %4451, %137, SIGNED : (tensor<3x2x8xi64>, tensor<3x2x8xi64>) -> tensor<3x2x8xi1> + %4453 = stablehlo.convert %4452 : (tensor<3x2x8xi1>) -> tensor<3x2x8xi64> + %4454 = stablehlo.transpose %4453, dims = [2, 1, 0] : (tensor<3x2x8xi64>) -> tensor<8x2x3xi64> + %4455 = stablehlo.slice %4454 [0:1, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %4456 = stablehlo.reshape %4455 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %4457 = stablehlo.custom_call @byteir.non_zero(%4456) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_1536 = tensor.dim %4457, %c0 : tensor + %4458 = arith.index_cast %dim_1536 : index to i64 + %from_elements_1537 = tensor.from_elements %4458, %c1_i64 : tensor<2xi64> + %4459 = stablehlo.real_dynamic_slice %4457, %c_22, %from_elements_1537, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1538 = tensor.dim %4459, %c0 : tensor + %4460 = arith.index_cast %dim_1538 : index to i64 + %from_elements_1539 = tensor.from_elements %4460 : tensor<1xi64> + %4461 = stablehlo.dynamic_reshape %4459, %from_elements_1539 : (tensor, tensor<1xi64>) -> tensor + %from_elements_1540 = tensor.from_elements %4458, %c2_i64 : tensor<2xi64> + %4462 = stablehlo.real_dynamic_slice %4457, %c_24, %from_elements_1540, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1541 = tensor.dim %4462, %c0 : tensor + %4463 = arith.index_cast %dim_1541 : index to i64 + %from_elements_1542 = tensor.from_elements %4463 : tensor<1xi64> + %4464 = stablehlo.dynamic_reshape %4462, %from_elements_1542 : (tensor, tensor<1xi64>) -> tensor + %4465 = stablehlo.reshape %4441 : (tensor<3x4096xf32>) -> tensor<1x3x4096xf32> + %dim_1543 = tensor.dim %4464, %c0 : tensor + %4466 = arith.index_cast %dim_1543 : index to i64 + %from_elements_1544 = tensor.from_elements %4466, %c1_i64 : tensor<2xi64> + %4467 = stablehlo.dynamic_reshape %4464, %from_elements_1544 : (tensor, tensor<2xi64>) -> tensor + %dim_1545 = tensor.dim %4467, %c0 : tensor + %4468 = arith.index_cast %dim_1545 : index to i64 + %from_elements_1546 = tensor.from_elements %c1_i64, %4468, %c4096_i64 : tensor<3xi64> + %4469 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_1546, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1547 = tensor.dim %4469, %c1 : tensor<1x?x4096xi64> + %4470 = arith.index_cast %dim_1547 : index to i64 + %from_elements_1548 = tensor.from_elements %c1_i64, %4470, %c4096_i64, %c1_i64 : tensor<4xi64> + %4471 = stablehlo.dynamic_reshape %4469, %from_elements_1548 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %4472 = 
stablehlo.dynamic_broadcast_in_dim %4467, %from_elements_1546, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1549 = tensor.dim %4472, %c1 : tensor<1x?x4096xi64> + %4473 = arith.index_cast %dim_1549 : index to i64 + %from_elements_1550 = tensor.from_elements %c1_i64, %4473, %c4096_i64, %c1_i64 : tensor<4xi64> + %4474 = stablehlo.dynamic_reshape %4472, %from_elements_1550 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %4475 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_1546, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1551 = tensor.dim %4475, %c1 : tensor<1x?x4096xi64> + %4476 = arith.index_cast %dim_1551 : index to i64 + %from_elements_1552 = tensor.from_elements %c1_i64, %4476, %c4096_i64, %c1_i64 : tensor<4xi64> + %4477 = stablehlo.dynamic_reshape %4475, %from_elements_1552 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %4478 = stablehlo.concatenate %4471, %4474, %4477, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %4479 = "stablehlo.gather"(%4465, %4478) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %4480 = shape.shape_of %4479 : tensor<1x?x4096xf32> -> tensor<3xindex> + %4481 = shape.num_elements %4480 : tensor<3xindex> -> index + %4482 = stablehlo.compute_reshape_shape %4481, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %4483 = stablehlo.dynamic_reshape %4479, %4482 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %4484 = stablehlo.dot %4483, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %4485 = stablehlo.logistic %4484 : tensor + %4486 = shape.shape_of %4485 : tensor -> tensor<2xindex> + %4487 = shape.shape_of %4484 : tensor -> tensor<2xindex> + %4488 = shape.cstr_broadcastable %4486, %4487 : tensor<2xindex>, tensor<2xindex> + %4489 = shape.assuming %4488 -> (tensor) { + %19688 = shape.broadcast %4486, %4487 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %4485, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %4484, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %4490 = shape.shape_of %4489 : tensor -> tensor<2xindex> + %4491 = shape.cstr_broadcastable %4490, %4487 : tensor<2xindex>, tensor<2xindex> + %4492 = shape.assuming %4491 -> (tensor) { + %19688 = shape.broadcast %4490, %4487 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %4489, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %4484, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %4493 = stablehlo.dot %4492, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %4494 = stablehlo.reshape %4449 : (tensor<3x2xf32>) -> tensor<3x2x1xf32> + %dim_1553 = tensor.dim %4464, %c0 : tensor + %4495 = arith.index_cast %dim_1553 : index to i64 + %from_elements_1554 = tensor.from_elements %4495, %c1_i64 : tensor<2xi64> + %4496 = stablehlo.dynamic_reshape %4464, %from_elements_1554 : (tensor, tensor<2xi64>) -> tensor + %dim_1555 = tensor.dim %4461, %c0 : tensor + %4497 = arith.index_cast %dim_1555 : index 
to i64 + %from_elements_1556 = tensor.from_elements %4497, %c1_i64 : tensor<2xi64> + %4498 = stablehlo.dynamic_reshape %4461, %from_elements_1556 : (tensor, tensor<2xi64>) -> tensor + %4499 = stablehlo.concatenate %4496, %4498, dim = 1 : (tensor, tensor) -> tensor + %4500 = "stablehlo.gather"(%4494, %4499) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %4501 = shape.shape_of %4493 : tensor -> tensor<2xindex> + %4502 = shape.shape_of %4500 : tensor -> tensor<2xindex> + %4503 = shape.cstr_broadcastable %4501, %4502 : tensor<2xindex>, tensor<2xindex> + %4504 = shape.assuming %4503 -> (tensor) { + %19688 = shape.broadcast %4501, %4502 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %4493, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %4500, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %4505 = shape.shape_of %4504 : tensor -> tensor<2xindex> + %4506 = stablehlo.dynamic_broadcast_in_dim %4504, %4505, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %4507 = stablehlo.dynamic_broadcast_in_dim %213, %4505, dims = [] : (tensor, tensor<2xindex>) -> tensor + %4508 = stablehlo.multiply %4506, %4507 : tensor + %dim_1557 = tensor.dim %4467, %c0 : tensor + %4509 = arith.index_cast %dim_1557 : index to i64 + %dim_1558 = tensor.dim %4504, %c0 : tensor + %4510 = arith.index_cast %dim_1558 : index to i64 + %4511 = arith.maxsi %4509, %4510 : i64 + %4512 = arith.index_cast %4511 : i64 to index + %from_elements_1559 = tensor.from_elements %4512, %c4096 : tensor<2xindex> + %4513 = stablehlo.dynamic_broadcast_in_dim %4467, %from_elements_1559, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_1560 = tensor.dim %4513, %c0 : tensor + %4514 = arith.index_cast %dim_1560 : index to i64 + %from_elements_1561 = tensor.from_elements %4514, %c4096_i64 : tensor<2xi64> + %4515 = stablehlo.real_dynamic_slice %4508, %c_22, %from_elements_1561, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_1562 = tensor.from_elements %4514, %c4096_i64, %c1_i64 : tensor<3xi64> + %4516 = stablehlo.dynamic_reshape %4513, %from_elements_1562 : (tensor, tensor<3xi64>) -> tensor + %4517 = stablehlo.dynamic_iota %from_elements_1562, dim = 1 : (tensor<3xi64>) -> tensor + %4518 = stablehlo.concatenate %4516, %4517, dim = 2 : (tensor, tensor) -> tensor + %4519 = "stablehlo.scatter"(%cst_2, %4518, %4515) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %4520 = stablehlo.slice %4454 [1:2, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %4521 = stablehlo.reshape %4520 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %4522 = stablehlo.custom_call @byteir.non_zero(%4521) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_1563 = tensor.dim %4522, %c0 : tensor + %4523 = arith.index_cast %dim_1563 : index to i64 + %from_elements_1564 = tensor.from_elements %4523, %c1_i64 : tensor<2xi64> + %4524 = stablehlo.real_dynamic_slice %4522, %c_22, %from_elements_1564, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1565 = 
tensor.dim %4524, %c0 : tensor + %4525 = arith.index_cast %dim_1565 : index to i64 + %from_elements_1566 = tensor.from_elements %4525 : tensor<1xi64> + %4526 = stablehlo.dynamic_reshape %4524, %from_elements_1566 : (tensor, tensor<1xi64>) -> tensor + %from_elements_1567 = tensor.from_elements %4523, %c2_i64 : tensor<2xi64> + %4527 = stablehlo.real_dynamic_slice %4522, %c_24, %from_elements_1567, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1568 = tensor.dim %4527, %c0 : tensor + %4528 = arith.index_cast %dim_1568 : index to i64 + %from_elements_1569 = tensor.from_elements %4528 : tensor<1xi64> + %4529 = stablehlo.dynamic_reshape %4527, %from_elements_1569 : (tensor, tensor<1xi64>) -> tensor + %dim_1570 = tensor.dim %4529, %c0 : tensor + %4530 = arith.index_cast %dim_1570 : index to i64 + %from_elements_1571 = tensor.from_elements %4530, %c1_i64 : tensor<2xi64> + %4531 = stablehlo.dynamic_reshape %4529, %from_elements_1571 : (tensor, tensor<2xi64>) -> tensor + %dim_1572 = tensor.dim %4531, %c0 : tensor + %4532 = arith.index_cast %dim_1572 : index to i64 + %from_elements_1573 = tensor.from_elements %c1_i64, %4532, %c4096_i64 : tensor<3xi64> + %4533 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_1573, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1574 = tensor.dim %4533, %c1 : tensor<1x?x4096xi64> + %4534 = arith.index_cast %dim_1574 : index to i64 + %from_elements_1575 = tensor.from_elements %c1_i64, %4534, %c4096_i64, %c1_i64 : tensor<4xi64> + %4535 = stablehlo.dynamic_reshape %4533, %from_elements_1575 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %4536 = stablehlo.dynamic_broadcast_in_dim %4531, %from_elements_1573, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1576 = tensor.dim %4536, %c1 : tensor<1x?x4096xi64> + %4537 = arith.index_cast %dim_1576 : index to i64 + %from_elements_1577 = tensor.from_elements %c1_i64, %4537, %c4096_i64, %c1_i64 : tensor<4xi64> + %4538 = stablehlo.dynamic_reshape %4536, %from_elements_1577 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %4539 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_1573, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1578 = tensor.dim %4539, %c1 : tensor<1x?x4096xi64> + %4540 = arith.index_cast %dim_1578 : index to i64 + %from_elements_1579 = tensor.from_elements %c1_i64, %4540, %c4096_i64, %c1_i64 : tensor<4xi64> + %4541 = stablehlo.dynamic_reshape %4539, %from_elements_1579 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %4542 = stablehlo.concatenate %4535, %4538, %4541, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %4543 = "stablehlo.gather"(%4465, %4542) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %4544 = shape.shape_of %4543 : tensor<1x?x4096xf32> -> tensor<3xindex> + %4545 = shape.num_elements %4544 : tensor<3xindex> -> index + %4546 = stablehlo.compute_reshape_shape %4545, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %4547 = stablehlo.dynamic_reshape %4543, %4546 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %4548 = stablehlo.dot %4547, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %4549 = stablehlo.logistic %4548 : tensor + %4550 = shape.shape_of %4549 : tensor -> tensor<2xindex> + %4551 = shape.shape_of %4548 : 
tensor -> tensor<2xindex> + %4552 = shape.cstr_broadcastable %4550, %4551 : tensor<2xindex>, tensor<2xindex> + %4553 = shape.assuming %4552 -> (tensor) { + %19688 = shape.broadcast %4550, %4551 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %4549, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %4548, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %4554 = shape.shape_of %4553 : tensor -> tensor<2xindex> + %4555 = shape.cstr_broadcastable %4554, %4551 : tensor<2xindex>, tensor<2xindex> + %4556 = shape.assuming %4555 -> (tensor) { + %19688 = shape.broadcast %4554, %4551 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %4553, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %4548, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %4557 = stablehlo.dot %4556, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_1580 = tensor.dim %4529, %c0 : tensor + %4558 = arith.index_cast %dim_1580 : index to i64 + %from_elements_1581 = tensor.from_elements %4558, %c1_i64 : tensor<2xi64> + %4559 = stablehlo.dynamic_reshape %4529, %from_elements_1581 : (tensor, tensor<2xi64>) -> tensor + %dim_1582 = tensor.dim %4526, %c0 : tensor + %4560 = arith.index_cast %dim_1582 : index to i64 + %from_elements_1583 = tensor.from_elements %4560, %c1_i64 : tensor<2xi64> + %4561 = stablehlo.dynamic_reshape %4526, %from_elements_1583 : (tensor, tensor<2xi64>) -> tensor + %4562 = stablehlo.concatenate %4559, %4561, dim = 1 : (tensor, tensor) -> tensor + %4563 = "stablehlo.gather"(%4494, %4562) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %4564 = shape.shape_of %4557 : tensor -> tensor<2xindex> + %4565 = shape.shape_of %4563 : tensor -> tensor<2xindex> + %4566 = shape.cstr_broadcastable %4564, %4565 : tensor<2xindex>, tensor<2xindex> + %4567 = shape.assuming %4566 -> (tensor) { + %19688 = shape.broadcast %4564, %4565 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %4557, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %4563, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %4568 = shape.shape_of %4567 : tensor -> tensor<2xindex> + %4569 = stablehlo.dynamic_broadcast_in_dim %4567, %4568, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %4570 = stablehlo.dynamic_broadcast_in_dim %213, %4568, dims = [] : (tensor, tensor<2xindex>) -> tensor + %4571 = stablehlo.multiply %4569, %4570 : tensor + %dim_1584 = tensor.dim %4531, %c0 : tensor + %4572 = arith.index_cast %dim_1584 : index to i64 + %dim_1585 = tensor.dim %4567, %c0 : tensor + %4573 = arith.index_cast %dim_1585 : index to i64 + %4574 = arith.maxsi %4572, %4573 : i64 + %4575 = arith.index_cast %4574 : i64 to index + %from_elements_1586 = tensor.from_elements %4575, %c4096 : tensor<2xindex> + %4576 = stablehlo.dynamic_broadcast_in_dim %4531, %from_elements_1586, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_1587 = 
tensor.dim %4576, %c0 : tensor + %4577 = arith.index_cast %dim_1587 : index to i64 + %from_elements_1588 = tensor.from_elements %4577, %c4096_i64 : tensor<2xi64> + %4578 = stablehlo.real_dynamic_slice %4571, %c_22, %from_elements_1588, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_1589 = tensor.from_elements %4577, %c4096_i64, %c1_i64 : tensor<3xi64> + %4579 = stablehlo.dynamic_reshape %4576, %from_elements_1589 : (tensor, tensor<3xi64>) -> tensor + %4580 = stablehlo.dynamic_iota %from_elements_1589, dim = 1 : (tensor<3xi64>) -> tensor + %4581 = stablehlo.concatenate %4579, %4580, dim = 2 : (tensor, tensor) -> tensor + %4582 = "stablehlo.scatter"(%4519, %4581, %4578) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %4583 = stablehlo.slice %4454 [2:3, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %4584 = stablehlo.reshape %4583 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %4585 = stablehlo.custom_call @byteir.non_zero(%4584) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_1590 = tensor.dim %4585, %c0 : tensor + %4586 = arith.index_cast %dim_1590 : index to i64 + %from_elements_1591 = tensor.from_elements %4586, %c1_i64 : tensor<2xi64> + %4587 = stablehlo.real_dynamic_slice %4585, %c_22, %from_elements_1591, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1592 = tensor.dim %4587, %c0 : tensor + %4588 = arith.index_cast %dim_1592 : index to i64 + %from_elements_1593 = tensor.from_elements %4588 : tensor<1xi64> + %4589 = stablehlo.dynamic_reshape %4587, %from_elements_1593 : (tensor, tensor<1xi64>) -> tensor + %from_elements_1594 = tensor.from_elements %4586, %c2_i64 : tensor<2xi64> + %4590 = stablehlo.real_dynamic_slice %4585, %c_24, %from_elements_1594, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1595 = tensor.dim %4590, %c0 : tensor + %4591 = arith.index_cast %dim_1595 : index to i64 + %from_elements_1596 = tensor.from_elements %4591 : tensor<1xi64> + %4592 = stablehlo.dynamic_reshape %4590, %from_elements_1596 : (tensor, tensor<1xi64>) -> tensor + %dim_1597 = tensor.dim %4592, %c0 : tensor + %4593 = arith.index_cast %dim_1597 : index to i64 + %from_elements_1598 = tensor.from_elements %4593, %c1_i64 : tensor<2xi64> + %4594 = stablehlo.dynamic_reshape %4592, %from_elements_1598 : (tensor, tensor<2xi64>) -> tensor + %dim_1599 = tensor.dim %4594, %c0 : tensor + %4595 = arith.index_cast %dim_1599 : index to i64 + %from_elements_1600 = tensor.from_elements %c1_i64, %4595, %c4096_i64 : tensor<3xi64> + %4596 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_1600, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1601 = tensor.dim %4596, %c1 : tensor<1x?x4096xi64> + %4597 = arith.index_cast %dim_1601 : index to i64 + %from_elements_1602 = tensor.from_elements %c1_i64, %4597, %c4096_i64, %c1_i64 : tensor<4xi64> + %4598 = stablehlo.dynamic_reshape %4596, %from_elements_1602 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %4599 = stablehlo.dynamic_broadcast_in_dim %4594, %from_elements_1600, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1603 = tensor.dim %4599, %c1 : tensor<1x?x4096xi64> + %4600 = arith.index_cast %dim_1603 : index to i64 + 
%from_elements_1604 = tensor.from_elements %c1_i64, %4600, %c4096_i64, %c1_i64 : tensor<4xi64> + %4601 = stablehlo.dynamic_reshape %4599, %from_elements_1604 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %4602 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_1600, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1605 = tensor.dim %4602, %c1 : tensor<1x?x4096xi64> + %4603 = arith.index_cast %dim_1605 : index to i64 + %from_elements_1606 = tensor.from_elements %c1_i64, %4603, %c4096_i64, %c1_i64 : tensor<4xi64> + %4604 = stablehlo.dynamic_reshape %4602, %from_elements_1606 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %4605 = stablehlo.concatenate %4598, %4601, %4604, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %4606 = "stablehlo.gather"(%4465, %4605) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %4607 = shape.shape_of %4606 : tensor<1x?x4096xf32> -> tensor<3xindex> + %4608 = shape.num_elements %4607 : tensor<3xindex> -> index + %4609 = stablehlo.compute_reshape_shape %4608, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %4610 = stablehlo.dynamic_reshape %4606, %4609 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %4611 = stablehlo.dot %4610, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %4612 = stablehlo.logistic %4611 : tensor + %4613 = shape.shape_of %4612 : tensor -> tensor<2xindex> + %4614 = shape.shape_of %4611 : tensor -> tensor<2xindex> + %4615 = shape.cstr_broadcastable %4613, %4614 : tensor<2xindex>, tensor<2xindex> + %4616 = shape.assuming %4615 -> (tensor) { + %19688 = shape.broadcast %4613, %4614 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %4612, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %4611, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %4617 = shape.shape_of %4616 : tensor -> tensor<2xindex> + %4618 = shape.cstr_broadcastable %4617, %4614 : tensor<2xindex>, tensor<2xindex> + %4619 = shape.assuming %4618 -> (tensor) { + %19688 = shape.broadcast %4617, %4614 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %4616, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %4611, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %4620 = stablehlo.dot %4619, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_1607 = tensor.dim %4592, %c0 : tensor + %4621 = arith.index_cast %dim_1607 : index to i64 + %from_elements_1608 = tensor.from_elements %4621, %c1_i64 : tensor<2xi64> + %4622 = stablehlo.dynamic_reshape %4592, %from_elements_1608 : (tensor, tensor<2xi64>) -> tensor + %dim_1609 = tensor.dim %4589, %c0 : tensor + %4623 = arith.index_cast %dim_1609 : index to i64 + %from_elements_1610 = tensor.from_elements %4623, %c1_i64 : tensor<2xi64> + %4624 = stablehlo.dynamic_reshape %4589, %from_elements_1610 : (tensor, tensor<2xi64>) -> tensor + %4625 = stablehlo.concatenate %4622, %4624, dim = 1 : (tensor, tensor) -> tensor + %4626 = "stablehlo.gather"(%4494, %4625) 
<{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %4627 = shape.shape_of %4620 : tensor -> tensor<2xindex> + %4628 = shape.shape_of %4626 : tensor -> tensor<2xindex> + %4629 = shape.cstr_broadcastable %4627, %4628 : tensor<2xindex>, tensor<2xindex> + %4630 = shape.assuming %4629 -> (tensor) { + %19688 = shape.broadcast %4627, %4628 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %4620, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %4626, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %4631 = shape.shape_of %4630 : tensor -> tensor<2xindex> + %4632 = stablehlo.dynamic_broadcast_in_dim %4630, %4631, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %4633 = stablehlo.dynamic_broadcast_in_dim %213, %4631, dims = [] : (tensor, tensor<2xindex>) -> tensor + %4634 = stablehlo.multiply %4632, %4633 : tensor + %dim_1611 = tensor.dim %4594, %c0 : tensor + %4635 = arith.index_cast %dim_1611 : index to i64 + %dim_1612 = tensor.dim %4630, %c0 : tensor + %4636 = arith.index_cast %dim_1612 : index to i64 + %4637 = arith.maxsi %4635, %4636 : i64 + %4638 = arith.index_cast %4637 : i64 to index + %from_elements_1613 = tensor.from_elements %4638, %c4096 : tensor<2xindex> + %4639 = stablehlo.dynamic_broadcast_in_dim %4594, %from_elements_1613, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_1614 = tensor.dim %4639, %c0 : tensor + %4640 = arith.index_cast %dim_1614 : index to i64 + %from_elements_1615 = tensor.from_elements %4640, %c4096_i64 : tensor<2xi64> + %4641 = stablehlo.real_dynamic_slice %4634, %c_22, %from_elements_1615, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_1616 = tensor.from_elements %4640, %c4096_i64, %c1_i64 : tensor<3xi64> + %4642 = stablehlo.dynamic_reshape %4639, %from_elements_1616 : (tensor, tensor<3xi64>) -> tensor + %4643 = stablehlo.dynamic_iota %from_elements_1616, dim = 1 : (tensor<3xi64>) -> tensor + %4644 = stablehlo.concatenate %4642, %4643, dim = 2 : (tensor, tensor) -> tensor + %4645 = "stablehlo.scatter"(%4582, %4644, %4641) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %4646 = stablehlo.slice %4454 [3:4, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %4647 = stablehlo.reshape %4646 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %4648 = stablehlo.custom_call @byteir.non_zero(%4647) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_1617 = tensor.dim %4648, %c0 : tensor + %4649 = arith.index_cast %dim_1617 : index to i64 + %from_elements_1618 = tensor.from_elements %4649, %c1_i64 : tensor<2xi64> + %4650 = stablehlo.real_dynamic_slice %4648, %c_22, %from_elements_1618, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1619 = tensor.dim %4650, %c0 : tensor + %4651 = arith.index_cast %dim_1619 : index to i64 + %from_elements_1620 = tensor.from_elements %4651 : tensor<1xi64> + %4652 = stablehlo.dynamic_reshape %4650, %from_elements_1620 : (tensor, tensor<1xi64>) -> tensor + %from_elements_1621 = tensor.from_elements %4649, %c2_i64 : 
tensor<2xi64> + %4653 = stablehlo.real_dynamic_slice %4648, %c_24, %from_elements_1621, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1622 = tensor.dim %4653, %c0 : tensor + %4654 = arith.index_cast %dim_1622 : index to i64 + %from_elements_1623 = tensor.from_elements %4654 : tensor<1xi64> + %4655 = stablehlo.dynamic_reshape %4653, %from_elements_1623 : (tensor, tensor<1xi64>) -> tensor + %dim_1624 = tensor.dim %4655, %c0 : tensor + %4656 = arith.index_cast %dim_1624 : index to i64 + %from_elements_1625 = tensor.from_elements %4656, %c1_i64 : tensor<2xi64> + %4657 = stablehlo.dynamic_reshape %4655, %from_elements_1625 : (tensor, tensor<2xi64>) -> tensor + %dim_1626 = tensor.dim %4657, %c0 : tensor + %4658 = arith.index_cast %dim_1626 : index to i64 + %from_elements_1627 = tensor.from_elements %c1_i64, %4658, %c4096_i64 : tensor<3xi64> + %4659 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_1627, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1628 = tensor.dim %4659, %c1 : tensor<1x?x4096xi64> + %4660 = arith.index_cast %dim_1628 : index to i64 + %from_elements_1629 = tensor.from_elements %c1_i64, %4660, %c4096_i64, %c1_i64 : tensor<4xi64> + %4661 = stablehlo.dynamic_reshape %4659, %from_elements_1629 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %4662 = stablehlo.dynamic_broadcast_in_dim %4657, %from_elements_1627, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1630 = tensor.dim %4662, %c1 : tensor<1x?x4096xi64> + %4663 = arith.index_cast %dim_1630 : index to i64 + %from_elements_1631 = tensor.from_elements %c1_i64, %4663, %c4096_i64, %c1_i64 : tensor<4xi64> + %4664 = stablehlo.dynamic_reshape %4662, %from_elements_1631 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %4665 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_1627, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1632 = tensor.dim %4665, %c1 : tensor<1x?x4096xi64> + %4666 = arith.index_cast %dim_1632 : index to i64 + %from_elements_1633 = tensor.from_elements %c1_i64, %4666, %c4096_i64, %c1_i64 : tensor<4xi64> + %4667 = stablehlo.dynamic_reshape %4665, %from_elements_1633 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %4668 = stablehlo.concatenate %4661, %4664, %4667, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %4669 = "stablehlo.gather"(%4465, %4668) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %4670 = shape.shape_of %4669 : tensor<1x?x4096xf32> -> tensor<3xindex> + %4671 = shape.num_elements %4670 : tensor<3xindex> -> index + %4672 = stablehlo.compute_reshape_shape %4671, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %4673 = stablehlo.dynamic_reshape %4669, %4672 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %4674 = stablehlo.dot %4673, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %4675 = stablehlo.logistic %4674 : tensor + %4676 = shape.shape_of %4675 : tensor -> tensor<2xindex> + %4677 = shape.shape_of %4674 : tensor -> tensor<2xindex> + %4678 = shape.cstr_broadcastable %4676, %4677 : tensor<2xindex>, tensor<2xindex> + %4679 = shape.assuming %4678 -> (tensor) { + %19688 = shape.broadcast %4676, %4677 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %4675, %19688, dims = 
[0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %4674, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %4680 = shape.shape_of %4679 : tensor -> tensor<2xindex> + %4681 = shape.cstr_broadcastable %4680, %4677 : tensor<2xindex>, tensor<2xindex> + %4682 = shape.assuming %4681 -> (tensor) { + %19688 = shape.broadcast %4680, %4677 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %4679, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %4674, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %4683 = stablehlo.dot %4682, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_1634 = tensor.dim %4655, %c0 : tensor + %4684 = arith.index_cast %dim_1634 : index to i64 + %from_elements_1635 = tensor.from_elements %4684, %c1_i64 : tensor<2xi64> + %4685 = stablehlo.dynamic_reshape %4655, %from_elements_1635 : (tensor, tensor<2xi64>) -> tensor + %dim_1636 = tensor.dim %4652, %c0 : tensor + %4686 = arith.index_cast %dim_1636 : index to i64 + %from_elements_1637 = tensor.from_elements %4686, %c1_i64 : tensor<2xi64> + %4687 = stablehlo.dynamic_reshape %4652, %from_elements_1637 : (tensor, tensor<2xi64>) -> tensor + %4688 = stablehlo.concatenate %4685, %4687, dim = 1 : (tensor, tensor) -> tensor + %4689 = "stablehlo.gather"(%4494, %4688) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %4690 = shape.shape_of %4683 : tensor -> tensor<2xindex> + %4691 = shape.shape_of %4689 : tensor -> tensor<2xindex> + %4692 = shape.cstr_broadcastable %4690, %4691 : tensor<2xindex>, tensor<2xindex> + %4693 = shape.assuming %4692 -> (tensor) { + %19688 = shape.broadcast %4690, %4691 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %4683, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %4689, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %4694 = shape.shape_of %4693 : tensor -> tensor<2xindex> + %4695 = stablehlo.dynamic_broadcast_in_dim %4693, %4694, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %4696 = stablehlo.dynamic_broadcast_in_dim %213, %4694, dims = [] : (tensor, tensor<2xindex>) -> tensor + %4697 = stablehlo.multiply %4695, %4696 : tensor + %dim_1638 = tensor.dim %4657, %c0 : tensor + %4698 = arith.index_cast %dim_1638 : index to i64 + %dim_1639 = tensor.dim %4693, %c0 : tensor + %4699 = arith.index_cast %dim_1639 : index to i64 + %4700 = arith.maxsi %4698, %4699 : i64 + %4701 = arith.index_cast %4700 : i64 to index + %from_elements_1640 = tensor.from_elements %4701, %c4096 : tensor<2xindex> + %4702 = stablehlo.dynamic_broadcast_in_dim %4657, %from_elements_1640, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_1641 = tensor.dim %4702, %c0 : tensor + %4703 = arith.index_cast %dim_1641 : index to i64 + %from_elements_1642 = tensor.from_elements %4703, %c4096_i64 : tensor<2xi64> + %4704 = stablehlo.real_dynamic_slice %4697, %c_22, %from_elements_1642, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + 
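+ // annotation added for readability; not part of the compiler output: after weighting, every expert branch is scatter-added (`stablehlo.scatter` with an add region) into the shared 3x4096 accumulator, so the eight unrolled branches together form the sparse top-2 mixture-of-experts combine; once the last expert's scatter completes, the accumulator is reshaped to 3x1x4096, added back to the residual stream, and followed by RMSNorm and the next decoder layer's attention.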
%from_elements_1643 = tensor.from_elements %4703, %c4096_i64, %c1_i64 : tensor<3xi64> + %4705 = stablehlo.dynamic_reshape %4702, %from_elements_1643 : (tensor, tensor<3xi64>) -> tensor + %4706 = stablehlo.dynamic_iota %from_elements_1643, dim = 1 : (tensor<3xi64>) -> tensor + %4707 = stablehlo.concatenate %4705, %4706, dim = 2 : (tensor, tensor) -> tensor + %4708 = "stablehlo.scatter"(%4645, %4707, %4704) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %4709 = stablehlo.slice %4454 [4:5, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %4710 = stablehlo.reshape %4709 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %4711 = stablehlo.custom_call @byteir.non_zero(%4710) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_1644 = tensor.dim %4711, %c0 : tensor + %4712 = arith.index_cast %dim_1644 : index to i64 + %from_elements_1645 = tensor.from_elements %4712, %c1_i64 : tensor<2xi64> + %4713 = stablehlo.real_dynamic_slice %4711, %c_22, %from_elements_1645, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1646 = tensor.dim %4713, %c0 : tensor + %4714 = arith.index_cast %dim_1646 : index to i64 + %from_elements_1647 = tensor.from_elements %4714 : tensor<1xi64> + %4715 = stablehlo.dynamic_reshape %4713, %from_elements_1647 : (tensor, tensor<1xi64>) -> tensor + %from_elements_1648 = tensor.from_elements %4712, %c2_i64 : tensor<2xi64> + %4716 = stablehlo.real_dynamic_slice %4711, %c_24, %from_elements_1648, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1649 = tensor.dim %4716, %c0 : tensor + %4717 = arith.index_cast %dim_1649 : index to i64 + %from_elements_1650 = tensor.from_elements %4717 : tensor<1xi64> + %4718 = stablehlo.dynamic_reshape %4716, %from_elements_1650 : (tensor, tensor<1xi64>) -> tensor + %dim_1651 = tensor.dim %4718, %c0 : tensor + %4719 = arith.index_cast %dim_1651 : index to i64 + %from_elements_1652 = tensor.from_elements %4719, %c1_i64 : tensor<2xi64> + %4720 = stablehlo.dynamic_reshape %4718, %from_elements_1652 : (tensor, tensor<2xi64>) -> tensor + %dim_1653 = tensor.dim %4720, %c0 : tensor + %4721 = arith.index_cast %dim_1653 : index to i64 + %from_elements_1654 = tensor.from_elements %c1_i64, %4721, %c4096_i64 : tensor<3xi64> + %4722 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_1654, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1655 = tensor.dim %4722, %c1 : tensor<1x?x4096xi64> + %4723 = arith.index_cast %dim_1655 : index to i64 + %from_elements_1656 = tensor.from_elements %c1_i64, %4723, %c4096_i64, %c1_i64 : tensor<4xi64> + %4724 = stablehlo.dynamic_reshape %4722, %from_elements_1656 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %4725 = stablehlo.dynamic_broadcast_in_dim %4720, %from_elements_1654, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1657 = tensor.dim %4725, %c1 : tensor<1x?x4096xi64> + %4726 = arith.index_cast %dim_1657 : index to i64 + %from_elements_1658 = tensor.from_elements %c1_i64, %4726, %c4096_i64, %c1_i64 : tensor<4xi64> + %4727 = stablehlo.dynamic_reshape %4725, %from_elements_1658 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %4728 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_1654, dims = [2] : 
(tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1659 = tensor.dim %4728, %c1 : tensor<1x?x4096xi64> + %4729 = arith.index_cast %dim_1659 : index to i64 + %from_elements_1660 = tensor.from_elements %c1_i64, %4729, %c4096_i64, %c1_i64 : tensor<4xi64> + %4730 = stablehlo.dynamic_reshape %4728, %from_elements_1660 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %4731 = stablehlo.concatenate %4724, %4727, %4730, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %4732 = "stablehlo.gather"(%4465, %4731) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %4733 = shape.shape_of %4732 : tensor<1x?x4096xf32> -> tensor<3xindex> + %4734 = shape.num_elements %4733 : tensor<3xindex> -> index + %4735 = stablehlo.compute_reshape_shape %4734, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %4736 = stablehlo.dynamic_reshape %4732, %4735 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %4737 = stablehlo.dot %4736, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %4738 = stablehlo.logistic %4737 : tensor + %4739 = shape.shape_of %4738 : tensor -> tensor<2xindex> + %4740 = shape.shape_of %4737 : tensor -> tensor<2xindex> + %4741 = shape.cstr_broadcastable %4739, %4740 : tensor<2xindex>, tensor<2xindex> + %4742 = shape.assuming %4741 -> (tensor) { + %19688 = shape.broadcast %4739, %4740 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %4738, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %4737, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %4743 = shape.shape_of %4742 : tensor -> tensor<2xindex> + %4744 = shape.cstr_broadcastable %4743, %4740 : tensor<2xindex>, tensor<2xindex> + %4745 = shape.assuming %4744 -> (tensor) { + %19688 = shape.broadcast %4743, %4740 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %4742, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %4737, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %4746 = stablehlo.dot %4745, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_1661 = tensor.dim %4718, %c0 : tensor + %4747 = arith.index_cast %dim_1661 : index to i64 + %from_elements_1662 = tensor.from_elements %4747, %c1_i64 : tensor<2xi64> + %4748 = stablehlo.dynamic_reshape %4718, %from_elements_1662 : (tensor, tensor<2xi64>) -> tensor + %dim_1663 = tensor.dim %4715, %c0 : tensor + %4749 = arith.index_cast %dim_1663 : index to i64 + %from_elements_1664 = tensor.from_elements %4749, %c1_i64 : tensor<2xi64> + %4750 = stablehlo.dynamic_reshape %4715, %from_elements_1664 : (tensor, tensor<2xi64>) -> tensor + %4751 = stablehlo.concatenate %4748, %4750, dim = 1 : (tensor, tensor) -> tensor + %4752 = "stablehlo.gather"(%4494, %4751) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %4753 = shape.shape_of %4746 : tensor -> tensor<2xindex> + %4754 = shape.shape_of %4752 : tensor -> tensor<2xindex> + %4755 = shape.cstr_broadcastable %4753, %4754 : tensor<2xindex>, 
tensor<2xindex> + %4756 = shape.assuming %4755 -> (tensor) { + %19688 = shape.broadcast %4753, %4754 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %4746, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %4752, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %4757 = shape.shape_of %4756 : tensor -> tensor<2xindex> + %4758 = stablehlo.dynamic_broadcast_in_dim %4756, %4757, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %4759 = stablehlo.dynamic_broadcast_in_dim %213, %4757, dims = [] : (tensor, tensor<2xindex>) -> tensor + %4760 = stablehlo.multiply %4758, %4759 : tensor + %dim_1665 = tensor.dim %4720, %c0 : tensor + %4761 = arith.index_cast %dim_1665 : index to i64 + %dim_1666 = tensor.dim %4756, %c0 : tensor + %4762 = arith.index_cast %dim_1666 : index to i64 + %4763 = arith.maxsi %4761, %4762 : i64 + %4764 = arith.index_cast %4763 : i64 to index + %from_elements_1667 = tensor.from_elements %4764, %c4096 : tensor<2xindex> + %4765 = stablehlo.dynamic_broadcast_in_dim %4720, %from_elements_1667, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_1668 = tensor.dim %4765, %c0 : tensor + %4766 = arith.index_cast %dim_1668 : index to i64 + %from_elements_1669 = tensor.from_elements %4766, %c4096_i64 : tensor<2xi64> + %4767 = stablehlo.real_dynamic_slice %4760, %c_22, %from_elements_1669, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_1670 = tensor.from_elements %4766, %c4096_i64, %c1_i64 : tensor<3xi64> + %4768 = stablehlo.dynamic_reshape %4765, %from_elements_1670 : (tensor, tensor<3xi64>) -> tensor + %4769 = stablehlo.dynamic_iota %from_elements_1670, dim = 1 : (tensor<3xi64>) -> tensor + %4770 = stablehlo.concatenate %4768, %4769, dim = 2 : (tensor, tensor) -> tensor + %4771 = "stablehlo.scatter"(%4708, %4770, %4767) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %4772 = stablehlo.slice %4454 [5:6, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %4773 = stablehlo.reshape %4772 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %4774 = stablehlo.custom_call @byteir.non_zero(%4773) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_1671 = tensor.dim %4774, %c0 : tensor + %4775 = arith.index_cast %dim_1671 : index to i64 + %from_elements_1672 = tensor.from_elements %4775, %c1_i64 : tensor<2xi64> + %4776 = stablehlo.real_dynamic_slice %4774, %c_22, %from_elements_1672, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1673 = tensor.dim %4776, %c0 : tensor + %4777 = arith.index_cast %dim_1673 : index to i64 + %from_elements_1674 = tensor.from_elements %4777 : tensor<1xi64> + %4778 = stablehlo.dynamic_reshape %4776, %from_elements_1674 : (tensor, tensor<1xi64>) -> tensor + %from_elements_1675 = tensor.from_elements %4775, %c2_i64 : tensor<2xi64> + %4779 = stablehlo.real_dynamic_slice %4774, %c_24, %from_elements_1675, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1676 = tensor.dim %4779, %c0 : tensor + %4780 = arith.index_cast %dim_1676 : index to i64 + %from_elements_1677 = tensor.from_elements %4780 : 
tensor<1xi64> + %4781 = stablehlo.dynamic_reshape %4779, %from_elements_1677 : (tensor, tensor<1xi64>) -> tensor + %dim_1678 = tensor.dim %4781, %c0 : tensor + %4782 = arith.index_cast %dim_1678 : index to i64 + %from_elements_1679 = tensor.from_elements %4782, %c1_i64 : tensor<2xi64> + %4783 = stablehlo.dynamic_reshape %4781, %from_elements_1679 : (tensor, tensor<2xi64>) -> tensor + %dim_1680 = tensor.dim %4783, %c0 : tensor + %4784 = arith.index_cast %dim_1680 : index to i64 + %from_elements_1681 = tensor.from_elements %c1_i64, %4784, %c4096_i64 : tensor<3xi64> + %4785 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_1681, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1682 = tensor.dim %4785, %c1 : tensor<1x?x4096xi64> + %4786 = arith.index_cast %dim_1682 : index to i64 + %from_elements_1683 = tensor.from_elements %c1_i64, %4786, %c4096_i64, %c1_i64 : tensor<4xi64> + %4787 = stablehlo.dynamic_reshape %4785, %from_elements_1683 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %4788 = stablehlo.dynamic_broadcast_in_dim %4783, %from_elements_1681, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1684 = tensor.dim %4788, %c1 : tensor<1x?x4096xi64> + %4789 = arith.index_cast %dim_1684 : index to i64 + %from_elements_1685 = tensor.from_elements %c1_i64, %4789, %c4096_i64, %c1_i64 : tensor<4xi64> + %4790 = stablehlo.dynamic_reshape %4788, %from_elements_1685 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %4791 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_1681, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1686 = tensor.dim %4791, %c1 : tensor<1x?x4096xi64> + %4792 = arith.index_cast %dim_1686 : index to i64 + %from_elements_1687 = tensor.from_elements %c1_i64, %4792, %c4096_i64, %c1_i64 : tensor<4xi64> + %4793 = stablehlo.dynamic_reshape %4791, %from_elements_1687 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %4794 = stablehlo.concatenate %4787, %4790, %4793, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %4795 = "stablehlo.gather"(%4465, %4794) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %4796 = shape.shape_of %4795 : tensor<1x?x4096xf32> -> tensor<3xindex> + %4797 = shape.num_elements %4796 : tensor<3xindex> -> index + %4798 = stablehlo.compute_reshape_shape %4797, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %4799 = stablehlo.dynamic_reshape %4795, %4798 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %4800 = stablehlo.dot %4799, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %4801 = stablehlo.logistic %4800 : tensor + %4802 = shape.shape_of %4801 : tensor -> tensor<2xindex> + %4803 = shape.shape_of %4800 : tensor -> tensor<2xindex> + %4804 = shape.cstr_broadcastable %4802, %4803 : tensor<2xindex>, tensor<2xindex> + %4805 = shape.assuming %4804 -> (tensor) { + %19688 = shape.broadcast %4802, %4803 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %4801, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %4800, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %4806 = shape.shape_of %4805 : tensor -> 
tensor<2xindex> + %4807 = shape.cstr_broadcastable %4806, %4803 : tensor<2xindex>, tensor<2xindex> + %4808 = shape.assuming %4807 -> (tensor) { + %19688 = shape.broadcast %4806, %4803 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %4805, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %4800, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %4809 = stablehlo.dot %4808, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_1688 = tensor.dim %4781, %c0 : tensor + %4810 = arith.index_cast %dim_1688 : index to i64 + %from_elements_1689 = tensor.from_elements %4810, %c1_i64 : tensor<2xi64> + %4811 = stablehlo.dynamic_reshape %4781, %from_elements_1689 : (tensor, tensor<2xi64>) -> tensor + %dim_1690 = tensor.dim %4778, %c0 : tensor + %4812 = arith.index_cast %dim_1690 : index to i64 + %from_elements_1691 = tensor.from_elements %4812, %c1_i64 : tensor<2xi64> + %4813 = stablehlo.dynamic_reshape %4778, %from_elements_1691 : (tensor, tensor<2xi64>) -> tensor + %4814 = stablehlo.concatenate %4811, %4813, dim = 1 : (tensor, tensor) -> tensor + %4815 = "stablehlo.gather"(%4494, %4814) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %4816 = shape.shape_of %4809 : tensor -> tensor<2xindex> + %4817 = shape.shape_of %4815 : tensor -> tensor<2xindex> + %4818 = shape.cstr_broadcastable %4816, %4817 : tensor<2xindex>, tensor<2xindex> + %4819 = shape.assuming %4818 -> (tensor) { + %19688 = shape.broadcast %4816, %4817 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %4809, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %4815, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %4820 = shape.shape_of %4819 : tensor -> tensor<2xindex> + %4821 = stablehlo.dynamic_broadcast_in_dim %4819, %4820, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %4822 = stablehlo.dynamic_broadcast_in_dim %213, %4820, dims = [] : (tensor, tensor<2xindex>) -> tensor + %4823 = stablehlo.multiply %4821, %4822 : tensor + %dim_1692 = tensor.dim %4783, %c0 : tensor + %4824 = arith.index_cast %dim_1692 : index to i64 + %dim_1693 = tensor.dim %4819, %c0 : tensor + %4825 = arith.index_cast %dim_1693 : index to i64 + %4826 = arith.maxsi %4824, %4825 : i64 + %4827 = arith.index_cast %4826 : i64 to index + %from_elements_1694 = tensor.from_elements %4827, %c4096 : tensor<2xindex> + %4828 = stablehlo.dynamic_broadcast_in_dim %4783, %from_elements_1694, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_1695 = tensor.dim %4828, %c0 : tensor + %4829 = arith.index_cast %dim_1695 : index to i64 + %from_elements_1696 = tensor.from_elements %4829, %c4096_i64 : tensor<2xi64> + %4830 = stablehlo.real_dynamic_slice %4823, %c_22, %from_elements_1696, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_1697 = tensor.from_elements %4829, %c4096_i64, %c1_i64 : tensor<3xi64> + %4831 = stablehlo.dynamic_reshape %4828, %from_elements_1697 : (tensor, tensor<3xi64>) -> tensor + %4832 = stablehlo.dynamic_iota %from_elements_1697, dim = 1 : (tensor<3xi64>) -> tensor + %4833 = stablehlo.concatenate 
%4831, %4832, dim = 2 : (tensor, tensor) -> tensor + %4834 = "stablehlo.scatter"(%4771, %4833, %4830) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %4835 = stablehlo.slice %4454 [6:7, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %4836 = stablehlo.reshape %4835 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %4837 = stablehlo.custom_call @byteir.non_zero(%4836) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_1698 = tensor.dim %4837, %c0 : tensor + %4838 = arith.index_cast %dim_1698 : index to i64 + %from_elements_1699 = tensor.from_elements %4838, %c1_i64 : tensor<2xi64> + %4839 = stablehlo.real_dynamic_slice %4837, %c_22, %from_elements_1699, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1700 = tensor.dim %4839, %c0 : tensor + %4840 = arith.index_cast %dim_1700 : index to i64 + %from_elements_1701 = tensor.from_elements %4840 : tensor<1xi64> + %4841 = stablehlo.dynamic_reshape %4839, %from_elements_1701 : (tensor, tensor<1xi64>) -> tensor + %from_elements_1702 = tensor.from_elements %4838, %c2_i64 : tensor<2xi64> + %4842 = stablehlo.real_dynamic_slice %4837, %c_24, %from_elements_1702, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1703 = tensor.dim %4842, %c0 : tensor + %4843 = arith.index_cast %dim_1703 : index to i64 + %from_elements_1704 = tensor.from_elements %4843 : tensor<1xi64> + %4844 = stablehlo.dynamic_reshape %4842, %from_elements_1704 : (tensor, tensor<1xi64>) -> tensor + %dim_1705 = tensor.dim %4844, %c0 : tensor + %4845 = arith.index_cast %dim_1705 : index to i64 + %from_elements_1706 = tensor.from_elements %4845, %c1_i64 : tensor<2xi64> + %4846 = stablehlo.dynamic_reshape %4844, %from_elements_1706 : (tensor, tensor<2xi64>) -> tensor + %dim_1707 = tensor.dim %4846, %c0 : tensor + %4847 = arith.index_cast %dim_1707 : index to i64 + %from_elements_1708 = tensor.from_elements %c1_i64, %4847, %c4096_i64 : tensor<3xi64> + %4848 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_1708, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1709 = tensor.dim %4848, %c1 : tensor<1x?x4096xi64> + %4849 = arith.index_cast %dim_1709 : index to i64 + %from_elements_1710 = tensor.from_elements %c1_i64, %4849, %c4096_i64, %c1_i64 : tensor<4xi64> + %4850 = stablehlo.dynamic_reshape %4848, %from_elements_1710 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %4851 = stablehlo.dynamic_broadcast_in_dim %4846, %from_elements_1708, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1711 = tensor.dim %4851, %c1 : tensor<1x?x4096xi64> + %4852 = arith.index_cast %dim_1711 : index to i64 + %from_elements_1712 = tensor.from_elements %c1_i64, %4852, %c4096_i64, %c1_i64 : tensor<4xi64> + %4853 = stablehlo.dynamic_reshape %4851, %from_elements_1712 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %4854 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_1708, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1713 = tensor.dim %4854, %c1 : tensor<1x?x4096xi64> + %4855 = arith.index_cast %dim_1713 : index to i64 + %from_elements_1714 = tensor.from_elements %c1_i64, %4855, %c4096_i64, %c1_i64 : tensor<4xi64> + %4856 = stablehlo.dynamic_reshape %4854, 
%from_elements_1714 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %4857 = stablehlo.concatenate %4850, %4853, %4856, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %4858 = "stablehlo.gather"(%4465, %4857) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %4859 = shape.shape_of %4858 : tensor<1x?x4096xf32> -> tensor<3xindex> + %4860 = shape.num_elements %4859 : tensor<3xindex> -> index + %4861 = stablehlo.compute_reshape_shape %4860, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %4862 = stablehlo.dynamic_reshape %4858, %4861 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %4863 = stablehlo.dot %4862, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %4864 = stablehlo.logistic %4863 : tensor + %4865 = shape.shape_of %4864 : tensor -> tensor<2xindex> + %4866 = shape.shape_of %4863 : tensor -> tensor<2xindex> + %4867 = shape.cstr_broadcastable %4865, %4866 : tensor<2xindex>, tensor<2xindex> + %4868 = shape.assuming %4867 -> (tensor) { + %19688 = shape.broadcast %4865, %4866 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %4864, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %4863, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %4869 = shape.shape_of %4868 : tensor -> tensor<2xindex> + %4870 = shape.cstr_broadcastable %4869, %4866 : tensor<2xindex>, tensor<2xindex> + %4871 = shape.assuming %4870 -> (tensor) { + %19688 = shape.broadcast %4869, %4866 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %4868, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %4863, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %4872 = stablehlo.dot %4871, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_1715 = tensor.dim %4844, %c0 : tensor + %4873 = arith.index_cast %dim_1715 : index to i64 + %from_elements_1716 = tensor.from_elements %4873, %c1_i64 : tensor<2xi64> + %4874 = stablehlo.dynamic_reshape %4844, %from_elements_1716 : (tensor, tensor<2xi64>) -> tensor + %dim_1717 = tensor.dim %4841, %c0 : tensor + %4875 = arith.index_cast %dim_1717 : index to i64 + %from_elements_1718 = tensor.from_elements %4875, %c1_i64 : tensor<2xi64> + %4876 = stablehlo.dynamic_reshape %4841, %from_elements_1718 : (tensor, tensor<2xi64>) -> tensor + %4877 = stablehlo.concatenate %4874, %4876, dim = 1 : (tensor, tensor) -> tensor + %4878 = "stablehlo.gather"(%4494, %4877) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %4879 = shape.shape_of %4872 : tensor -> tensor<2xindex> + %4880 = shape.shape_of %4878 : tensor -> tensor<2xindex> + %4881 = shape.cstr_broadcastable %4879, %4880 : tensor<2xindex>, tensor<2xindex> + %4882 = shape.assuming %4881 -> (tensor) { + %19688 = shape.broadcast %4879, %4880 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %4872, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = 
stablehlo.dynamic_broadcast_in_dim %4878, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %4883 = shape.shape_of %4882 : tensor -> tensor<2xindex> + %4884 = stablehlo.dynamic_broadcast_in_dim %4882, %4883, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %4885 = stablehlo.dynamic_broadcast_in_dim %213, %4883, dims = [] : (tensor, tensor<2xindex>) -> tensor + %4886 = stablehlo.multiply %4884, %4885 : tensor + %dim_1719 = tensor.dim %4846, %c0 : tensor + %4887 = arith.index_cast %dim_1719 : index to i64 + %dim_1720 = tensor.dim %4882, %c0 : tensor + %4888 = arith.index_cast %dim_1720 : index to i64 + %4889 = arith.maxsi %4887, %4888 : i64 + %4890 = arith.index_cast %4889 : i64 to index + %from_elements_1721 = tensor.from_elements %4890, %c4096 : tensor<2xindex> + %4891 = stablehlo.dynamic_broadcast_in_dim %4846, %from_elements_1721, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_1722 = tensor.dim %4891, %c0 : tensor + %4892 = arith.index_cast %dim_1722 : index to i64 + %from_elements_1723 = tensor.from_elements %4892, %c4096_i64 : tensor<2xi64> + %4893 = stablehlo.real_dynamic_slice %4886, %c_22, %from_elements_1723, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_1724 = tensor.from_elements %4892, %c4096_i64, %c1_i64 : tensor<3xi64> + %4894 = stablehlo.dynamic_reshape %4891, %from_elements_1724 : (tensor, tensor<3xi64>) -> tensor + %4895 = stablehlo.dynamic_iota %from_elements_1724, dim = 1 : (tensor<3xi64>) -> tensor + %4896 = stablehlo.concatenate %4894, %4895, dim = 2 : (tensor, tensor) -> tensor + %4897 = "stablehlo.scatter"(%4834, %4896, %4893) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %4898 = stablehlo.slice %4454 [7:8, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %4899 = stablehlo.reshape %4898 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %4900 = stablehlo.custom_call @byteir.non_zero(%4899) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_1725 = tensor.dim %4900, %c0 : tensor + %4901 = arith.index_cast %dim_1725 : index to i64 + %from_elements_1726 = tensor.from_elements %4901, %c1_i64 : tensor<2xi64> + %4902 = stablehlo.real_dynamic_slice %4900, %c_22, %from_elements_1726, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1727 = tensor.dim %4902, %c0 : tensor + %4903 = arith.index_cast %dim_1727 : index to i64 + %from_elements_1728 = tensor.from_elements %4903 : tensor<1xi64> + %4904 = stablehlo.dynamic_reshape %4902, %from_elements_1728 : (tensor, tensor<1xi64>) -> tensor + %from_elements_1729 = tensor.from_elements %4901, %c2_i64 : tensor<2xi64> + %4905 = stablehlo.real_dynamic_slice %4900, %c_24, %from_elements_1729, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1730 = tensor.dim %4905, %c0 : tensor + %4906 = arith.index_cast %dim_1730 : index to i64 + %from_elements_1731 = tensor.from_elements %4906 : tensor<1xi64> + %4907 = stablehlo.dynamic_reshape %4905, %from_elements_1731 : (tensor, tensor<1xi64>) -> tensor + %dim_1732 = tensor.dim %4907, %c0 : tensor + %4908 = arith.index_cast %dim_1732 : index to i64 + %from_elements_1733 = tensor.from_elements %4908, %c1_i64 : 
tensor<2xi64> + %4909 = stablehlo.dynamic_reshape %4907, %from_elements_1733 : (tensor, tensor<2xi64>) -> tensor + %dim_1734 = tensor.dim %4909, %c0 : tensor + %4910 = arith.index_cast %dim_1734 : index to i64 + %from_elements_1735 = tensor.from_elements %c1_i64, %4910, %c4096_i64 : tensor<3xi64> + %4911 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_1735, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1736 = tensor.dim %4911, %c1 : tensor<1x?x4096xi64> + %4912 = arith.index_cast %dim_1736 : index to i64 + %from_elements_1737 = tensor.from_elements %c1_i64, %4912, %c4096_i64, %c1_i64 : tensor<4xi64> + %4913 = stablehlo.dynamic_reshape %4911, %from_elements_1737 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %4914 = stablehlo.dynamic_broadcast_in_dim %4909, %from_elements_1735, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1738 = tensor.dim %4914, %c1 : tensor<1x?x4096xi64> + %4915 = arith.index_cast %dim_1738 : index to i64 + %from_elements_1739 = tensor.from_elements %c1_i64, %4915, %c4096_i64, %c1_i64 : tensor<4xi64> + %4916 = stablehlo.dynamic_reshape %4914, %from_elements_1739 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %4917 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_1735, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1740 = tensor.dim %4917, %c1 : tensor<1x?x4096xi64> + %4918 = arith.index_cast %dim_1740 : index to i64 + %from_elements_1741 = tensor.from_elements %c1_i64, %4918, %c4096_i64, %c1_i64 : tensor<4xi64> + %4919 = stablehlo.dynamic_reshape %4917, %from_elements_1741 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %4920 = stablehlo.concatenate %4913, %4916, %4919, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %4921 = "stablehlo.gather"(%4465, %4920) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %4922 = shape.shape_of %4921 : tensor<1x?x4096xf32> -> tensor<3xindex> + %4923 = shape.num_elements %4922 : tensor<3xindex> -> index + %4924 = stablehlo.compute_reshape_shape %4923, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %4925 = stablehlo.dynamic_reshape %4921, %4924 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %4926 = stablehlo.dot %4925, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %4927 = stablehlo.logistic %4926 : tensor + %4928 = shape.shape_of %4927 : tensor -> tensor<2xindex> + %4929 = shape.shape_of %4926 : tensor -> tensor<2xindex> + %4930 = shape.cstr_broadcastable %4928, %4929 : tensor<2xindex>, tensor<2xindex> + %4931 = shape.assuming %4930 -> (tensor) { + %19688 = shape.broadcast %4928, %4929 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %4927, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %4926, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %4932 = shape.shape_of %4931 : tensor -> tensor<2xindex> + %4933 = shape.cstr_broadcastable %4932, %4929 : tensor<2xindex>, tensor<2xindex> + %4934 = shape.assuming %4933 -> (tensor) { + %19688 = shape.broadcast %4932, %4929 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim 
%4931, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %4926, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %4935 = stablehlo.dot %4934, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_1742 = tensor.dim %4907, %c0 : tensor + %4936 = arith.index_cast %dim_1742 : index to i64 + %from_elements_1743 = tensor.from_elements %4936, %c1_i64 : tensor<2xi64> + %4937 = stablehlo.dynamic_reshape %4907, %from_elements_1743 : (tensor, tensor<2xi64>) -> tensor + %dim_1744 = tensor.dim %4904, %c0 : tensor + %4938 = arith.index_cast %dim_1744 : index to i64 + %from_elements_1745 = tensor.from_elements %4938, %c1_i64 : tensor<2xi64> + %4939 = stablehlo.dynamic_reshape %4904, %from_elements_1745 : (tensor, tensor<2xi64>) -> tensor + %4940 = stablehlo.concatenate %4937, %4939, dim = 1 : (tensor, tensor) -> tensor + %4941 = "stablehlo.gather"(%4494, %4940) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %4942 = shape.shape_of %4935 : tensor -> tensor<2xindex> + %4943 = shape.shape_of %4941 : tensor -> tensor<2xindex> + %4944 = shape.cstr_broadcastable %4942, %4943 : tensor<2xindex>, tensor<2xindex> + %4945 = shape.assuming %4944 -> (tensor) { + %19688 = shape.broadcast %4942, %4943 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %4935, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %4941, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %4946 = shape.shape_of %4945 : tensor -> tensor<2xindex> + %4947 = stablehlo.dynamic_broadcast_in_dim %4945, %4946, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %4948 = stablehlo.dynamic_broadcast_in_dim %213, %4946, dims = [] : (tensor, tensor<2xindex>) -> tensor + %4949 = stablehlo.multiply %4947, %4948 : tensor + %dim_1746 = tensor.dim %4909, %c0 : tensor + %4950 = arith.index_cast %dim_1746 : index to i64 + %dim_1747 = tensor.dim %4945, %c0 : tensor + %4951 = arith.index_cast %dim_1747 : index to i64 + %4952 = arith.maxsi %4950, %4951 : i64 + %4953 = arith.index_cast %4952 : i64 to index + %from_elements_1748 = tensor.from_elements %4953, %c4096 : tensor<2xindex> + %4954 = stablehlo.dynamic_broadcast_in_dim %4909, %from_elements_1748, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_1749 = tensor.dim %4954, %c0 : tensor + %4955 = arith.index_cast %dim_1749 : index to i64 + %from_elements_1750 = tensor.from_elements %4955, %c4096_i64 : tensor<2xi64> + %4956 = stablehlo.real_dynamic_slice %4949, %c_22, %from_elements_1750, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_1751 = tensor.from_elements %4955, %c4096_i64, %c1_i64 : tensor<3xi64> + %4957 = stablehlo.dynamic_reshape %4954, %from_elements_1751 : (tensor, tensor<3xi64>) -> tensor + %4958 = stablehlo.dynamic_iota %from_elements_1751, dim = 1 : (tensor<3xi64>) -> tensor + %4959 = stablehlo.concatenate %4957, %4958, dim = 2 : (tensor, tensor) -> tensor + %4960 = "stablehlo.scatter"(%4897, %4959, %4956) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, 
%arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %4961 = stablehlo.reshape %4960 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %4962 = stablehlo.add %4427, %4961 : tensor<3x1x4096xf32> + %4963 = stablehlo.broadcast_in_dim %4962, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %4964 = stablehlo.power %4963, %15 : tensor<3x1x4096xf32> + %4965 = stablehlo.reduce(%4964 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %4966 = stablehlo.reshape %4965 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %4967 = stablehlo.broadcast_in_dim %4966, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %4968 = stablehlo.divide %4967, %21 : tensor<3x1x1xf32> + %4969 = stablehlo.broadcast_in_dim %4968, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %4970 = stablehlo.add %4969, %25 : tensor<3x1x1xf32> + %4971 = stablehlo.rsqrt %4970 : tensor<3x1x1xf32> + %4972 = stablehlo.broadcast_in_dim %4971, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %4973 = stablehlo.multiply %4963, %4972 : tensor<3x1x4096xf32> + %4974 = stablehlo.broadcast_in_dim %4973, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %4975 = stablehlo.multiply %4974, %31 : tensor<3x1x4096xf32> + %4976 = stablehlo.reshape %4975 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %4977 = stablehlo.dot %4976, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %4978 = stablehlo.reshape %4977 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %4979 = stablehlo.dot %4976, %37 : (tensor<3x4096xf32>, tensor<4096x1024xf32>) -> tensor<3x1024xf32> + %4980 = stablehlo.reshape %4979 : (tensor<3x1024xf32>) -> tensor<3x1x1024xf32> + %4981 = stablehlo.reshape %4978 : (tensor<3x1x4096xf32>) -> tensor<3x1x32x128xf32> + %4982 = stablehlo.transpose %4981, dims = [0, 2, 1, 3] : (tensor<3x1x32x128xf32>) -> tensor<3x32x1x128xf32> + %4983 = stablehlo.reshape %4980 : (tensor<3x1x1024xf32>) -> tensor<3x1x8x128xf32> + %4984 = stablehlo.transpose %4983, dims = [0, 2, 1, 3] : (tensor<3x1x8x128xf32>) -> tensor<3x8x1x128xf32> + %4985 = stablehlo.slice %arg16 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %4986 = stablehlo.slice %arg17 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %4987 = "stablehlo.gather"(%4985, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %4988 = stablehlo.reshape %4987 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %4989 = "stablehlo.gather"(%4986, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %4990 = stablehlo.reshape %4989 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %4991 = stablehlo.broadcast_in_dim %4982, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %4992 = stablehlo.broadcast_in_dim %4988, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %4993 = stablehlo.multiply %4991, %4992 : tensor<3x32x1x128xf32> + %4994 = stablehlo.slice %4982 [0:3, 0:32, 0:1, 0:64] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %4995 = stablehlo.slice %4982 [0:3, 0:32, 0:1, 64:128] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %4996 = stablehlo.negate %4995 : tensor<3x32x1x64xf32> + %4997 = stablehlo.concatenate 
%4996, %4994, dim = 3 : (tensor<3x32x1x64xf32>, tensor<3x32x1x64xf32>) -> tensor<3x32x1x128xf32> + %4998 = stablehlo.broadcast_in_dim %4997, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %4999 = stablehlo.broadcast_in_dim %4990, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %5000 = stablehlo.multiply %4998, %4999 : tensor<3x32x1x128xf32> + %5001 = stablehlo.add %4993, %5000 : tensor<3x32x1x128xf32> + %5002 = stablehlo.broadcast_in_dim %4984, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %5003 = stablehlo.broadcast_in_dim %4988, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %5004 = stablehlo.multiply %5002, %5003 : tensor<3x8x1x128xf32> + %5005 = stablehlo.slice %4984 [0:3, 0:8, 0:1, 0:64] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %5006 = stablehlo.slice %4984 [0:3, 0:8, 0:1, 64:128] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %5007 = stablehlo.negate %5006 : tensor<3x8x1x64xf32> + %5008 = stablehlo.concatenate %5007, %5005, dim = 3 : (tensor<3x8x1x64xf32>, tensor<3x8x1x64xf32>) -> tensor<3x8x1x128xf32> + %5009 = stablehlo.broadcast_in_dim %5008, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %5010 = stablehlo.broadcast_in_dim %4990, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %5011 = stablehlo.multiply %5009, %5010 : tensor<3x8x1x128xf32> + %5012 = stablehlo.add %5004, %5011 : tensor<3x8x1x128xf32> + %5013 = stablehlo.concatenate %arg81, %5012, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %5014 = stablehlo.concatenate %arg82, %4984, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %5015 = stablehlo.reshape %5013 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %5016 = stablehlo.broadcast_in_dim %5015, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %5017 = stablehlo.reshape %5016 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %5018 = stablehlo.reshape %5014 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %5019 = stablehlo.broadcast_in_dim %5018, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %5020 = stablehlo.reshape %5019 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %5021 = stablehlo.transpose %5017, dims = [0, 1, 3, 2] : (tensor<3x32x8x128xf32>) -> tensor<3x32x128x8xf32> + %5022 = stablehlo.reshape %5001 : (tensor<3x32x1x128xf32>) -> tensor<96x1x128xf32> + %5023 = stablehlo.reshape %5021 : (tensor<3x32x128x8xf32>) -> tensor<96x128x8xf32> + %5024 = stablehlo.broadcast_in_dim %5023, dims = [0, 1, 2] : (tensor<96x128x8xf32>) -> tensor<96x128x8xf32> + %5025 = stablehlo.dot_general %5022, %5024, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x128xf32>, tensor<96x128x8xf32>) -> tensor<96x1x8xf32> + %5026 = stablehlo.reshape %5025 : (tensor<96x1x8xf32>) -> tensor<3x32x1x8xf32> + %5027 = stablehlo.broadcast_in_dim %5026, dims = [0, 1, 2, 3] : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %5028 = stablehlo.divide %5027, %89 : tensor<3x32x1x8xf32> + %5029 = stablehlo.custom_call @byteir.softmax(%5028) {byteir_attrs = {axis = 3 : i64}} : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %5030 = stablehlo.reshape %5029 : (tensor<3x32x1x8xf32>) -> tensor<96x1x8xf32> + %5031 = stablehlo.reshape %5020 : (tensor<3x32x8x128xf32>) -> tensor<96x8x128xf32> + %5032 = stablehlo.broadcast_in_dim %5031, dims = [0, 1, 2] : 
(tensor<96x8x128xf32>) -> tensor<96x8x128xf32> + %5033 = stablehlo.dot_general %5030, %5032, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x8xf32>, tensor<96x8x128xf32>) -> tensor<96x1x128xf32> + %5034 = stablehlo.reshape %5033 : (tensor<96x1x128xf32>) -> tensor<3x32x1x128xf32> + %5035 = stablehlo.transpose %5034, dims = [0, 2, 1, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x1x32x128xf32> + %5036 = stablehlo.reshape %5035 : (tensor<3x1x32x128xf32>) -> tensor<3x1x4096xf32> + %5037 = stablehlo.reshape %5036 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %5038 = stablehlo.dot %5037, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %5039 = stablehlo.reshape %5038 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %5040 = stablehlo.add %4962, %5039 : tensor<3x1x4096xf32> + %5041 = stablehlo.broadcast_in_dim %5040, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %5042 = stablehlo.power %5041, %15 : tensor<3x1x4096xf32> + %5043 = stablehlo.reduce(%5042 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %5044 = stablehlo.reshape %5043 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %5045 = stablehlo.broadcast_in_dim %5044, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %5046 = stablehlo.divide %5045, %21 : tensor<3x1x1xf32> + %5047 = stablehlo.broadcast_in_dim %5046, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %5048 = stablehlo.add %5047, %25 : tensor<3x1x1xf32> + %5049 = stablehlo.rsqrt %5048 : tensor<3x1x1xf32> + %5050 = stablehlo.broadcast_in_dim %5049, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %5051 = stablehlo.multiply %5041, %5050 : tensor<3x1x4096xf32> + %5052 = stablehlo.broadcast_in_dim %5051, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %5053 = stablehlo.multiply %5052, %31 : tensor<3x1x4096xf32> + %5054 = stablehlo.reshape %5053 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %5055 = stablehlo.dot %5054, %117 : (tensor<3x4096xf32>, tensor<4096x8xf32>) -> tensor<3x8xf32> + %5056 = stablehlo.custom_call @byteir.softmax(%5055) {byteir_attrs = {axis = 1 : i64}} : (tensor<3x8xf32>) -> tensor<3x8xf32> + %5057:2 = stablehlo.custom_call @byteir.top_k(%5056) {byteir_attrs = {axis = [1], k = 2 : i64, sorted = true}} : (tensor<3x8xf32>) -> (tensor<3x2xf32>, tensor<3x2xi64>) + %5058 = stablehlo.reduce(%5057#0 init: %cst_20) applies stablehlo.add across dimensions = [1] : (tensor<3x2xf32>, tensor) -> tensor<3xf32> + %5059 = stablehlo.reshape %5058 : (tensor<3xf32>) -> tensor<3x1xf32> + %5060 = stablehlo.broadcast_in_dim %5057#0, dims = [0, 1] : (tensor<3x2xf32>) -> tensor<3x2xf32> + %5061 = stablehlo.broadcast_in_dim %5059, dims = [0, 1] : (tensor<3x1xf32>) -> tensor<3x2xf32> + %5062 = stablehlo.divide %5060, %5061 : tensor<3x2xf32> + %5063 = stablehlo.reshape %5057#1 : (tensor<3x2xi64>) -> tensor<3x2x1xi64> + %5064 = stablehlo.broadcast_in_dim %5063, dims = [0, 1, 2] : (tensor<3x2x1xi64>) -> tensor<3x2x8xi64> + %5065 = stablehlo.compare EQ, %5064, %137, SIGNED : (tensor<3x2x8xi64>, tensor<3x2x8xi64>) -> tensor<3x2x8xi1> + %5066 = stablehlo.convert %5065 : (tensor<3x2x8xi1>) -> tensor<3x2x8xi64> + %5067 = stablehlo.transpose %5066, dims = [2, 1, 0] : (tensor<3x2x8xi64>) -> tensor<8x2x3xi64> + %5068 = stablehlo.slice %5067 [0:1, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %5069 = stablehlo.reshape %5068 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %5070 = stablehlo.custom_call 
@byteir.non_zero(%5069) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_1752 = tensor.dim %5070, %c0 : tensor + %5071 = arith.index_cast %dim_1752 : index to i64 + %from_elements_1753 = tensor.from_elements %5071, %c1_i64 : tensor<2xi64> + %5072 = stablehlo.real_dynamic_slice %5070, %c_22, %from_elements_1753, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1754 = tensor.dim %5072, %c0 : tensor + %5073 = arith.index_cast %dim_1754 : index to i64 + %from_elements_1755 = tensor.from_elements %5073 : tensor<1xi64> + %5074 = stablehlo.dynamic_reshape %5072, %from_elements_1755 : (tensor, tensor<1xi64>) -> tensor + %from_elements_1756 = tensor.from_elements %5071, %c2_i64 : tensor<2xi64> + %5075 = stablehlo.real_dynamic_slice %5070, %c_24, %from_elements_1756, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1757 = tensor.dim %5075, %c0 : tensor + %5076 = arith.index_cast %dim_1757 : index to i64 + %from_elements_1758 = tensor.from_elements %5076 : tensor<1xi64> + %5077 = stablehlo.dynamic_reshape %5075, %from_elements_1758 : (tensor, tensor<1xi64>) -> tensor + %5078 = stablehlo.reshape %5054 : (tensor<3x4096xf32>) -> tensor<1x3x4096xf32> + %dim_1759 = tensor.dim %5077, %c0 : tensor + %5079 = arith.index_cast %dim_1759 : index to i64 + %from_elements_1760 = tensor.from_elements %5079, %c1_i64 : tensor<2xi64> + %5080 = stablehlo.dynamic_reshape %5077, %from_elements_1760 : (tensor, tensor<2xi64>) -> tensor + %dim_1761 = tensor.dim %5080, %c0 : tensor + %5081 = arith.index_cast %dim_1761 : index to i64 + %from_elements_1762 = tensor.from_elements %c1_i64, %5081, %c4096_i64 : tensor<3xi64> + %5082 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_1762, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1763 = tensor.dim %5082, %c1 : tensor<1x?x4096xi64> + %5083 = arith.index_cast %dim_1763 : index to i64 + %from_elements_1764 = tensor.from_elements %c1_i64, %5083, %c4096_i64, %c1_i64 : tensor<4xi64> + %5084 = stablehlo.dynamic_reshape %5082, %from_elements_1764 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %5085 = stablehlo.dynamic_broadcast_in_dim %5080, %from_elements_1762, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1765 = tensor.dim %5085, %c1 : tensor<1x?x4096xi64> + %5086 = arith.index_cast %dim_1765 : index to i64 + %from_elements_1766 = tensor.from_elements %c1_i64, %5086, %c4096_i64, %c1_i64 : tensor<4xi64> + %5087 = stablehlo.dynamic_reshape %5085, %from_elements_1766 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %5088 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_1762, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1767 = tensor.dim %5088, %c1 : tensor<1x?x4096xi64> + %5089 = arith.index_cast %dim_1767 : index to i64 + %from_elements_1768 = tensor.from_elements %c1_i64, %5089, %c4096_i64, %c1_i64 : tensor<4xi64> + %5090 = stablehlo.dynamic_reshape %5088, %from_elements_1768 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %5091 = stablehlo.concatenate %5084, %5087, %5090, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %5092 = "stablehlo.gather"(%5078, %5091) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %5093 = shape.shape_of %5092 : tensor<1x?x4096xf32> -> 
tensor<3xindex> + %5094 = shape.num_elements %5093 : tensor<3xindex> -> index + %5095 = stablehlo.compute_reshape_shape %5094, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %5096 = stablehlo.dynamic_reshape %5092, %5095 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %5097 = stablehlo.dot %5096, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %5098 = stablehlo.logistic %5097 : tensor + %5099 = shape.shape_of %5098 : tensor -> tensor<2xindex> + %5100 = shape.shape_of %5097 : tensor -> tensor<2xindex> + %5101 = shape.cstr_broadcastable %5099, %5100 : tensor<2xindex>, tensor<2xindex> + %5102 = shape.assuming %5101 -> (tensor) { + %19688 = shape.broadcast %5099, %5100 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %5098, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %5097, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %5103 = shape.shape_of %5102 : tensor -> tensor<2xindex> + %5104 = shape.cstr_broadcastable %5103, %5100 : tensor<2xindex>, tensor<2xindex> + %5105 = shape.assuming %5104 -> (tensor) { + %19688 = shape.broadcast %5103, %5100 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %5102, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %5097, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %5106 = stablehlo.dot %5105, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %5107 = stablehlo.reshape %5062 : (tensor<3x2xf32>) -> tensor<3x2x1xf32> + %dim_1769 = tensor.dim %5077, %c0 : tensor + %5108 = arith.index_cast %dim_1769 : index to i64 + %from_elements_1770 = tensor.from_elements %5108, %c1_i64 : tensor<2xi64> + %5109 = stablehlo.dynamic_reshape %5077, %from_elements_1770 : (tensor, tensor<2xi64>) -> tensor + %dim_1771 = tensor.dim %5074, %c0 : tensor + %5110 = arith.index_cast %dim_1771 : index to i64 + %from_elements_1772 = tensor.from_elements %5110, %c1_i64 : tensor<2xi64> + %5111 = stablehlo.dynamic_reshape %5074, %from_elements_1772 : (tensor, tensor<2xi64>) -> tensor + %5112 = stablehlo.concatenate %5109, %5111, dim = 1 : (tensor, tensor) -> tensor + %5113 = "stablehlo.gather"(%5107, %5112) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %5114 = shape.shape_of %5106 : tensor -> tensor<2xindex> + %5115 = shape.shape_of %5113 : tensor -> tensor<2xindex> + %5116 = shape.cstr_broadcastable %5114, %5115 : tensor<2xindex>, tensor<2xindex> + %5117 = shape.assuming %5116 -> (tensor) { + %19688 = shape.broadcast %5114, %5115 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %5106, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %5113, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %5118 = shape.shape_of %5117 : tensor -> tensor<2xindex> + %5119 = stablehlo.dynamic_broadcast_in_dim %5117, %5118, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %5120 = stablehlo.dynamic_broadcast_in_dim %213, %5118, dims = [] : (tensor, tensor<2xindex>) -> 
tensor + %5121 = stablehlo.multiply %5119, %5120 : tensor + %dim_1773 = tensor.dim %5080, %c0 : tensor + %5122 = arith.index_cast %dim_1773 : index to i64 + %dim_1774 = tensor.dim %5117, %c0 : tensor + %5123 = arith.index_cast %dim_1774 : index to i64 + %5124 = arith.maxsi %5122, %5123 : i64 + %5125 = arith.index_cast %5124 : i64 to index + %from_elements_1775 = tensor.from_elements %5125, %c4096 : tensor<2xindex> + %5126 = stablehlo.dynamic_broadcast_in_dim %5080, %from_elements_1775, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_1776 = tensor.dim %5126, %c0 : tensor + %5127 = arith.index_cast %dim_1776 : index to i64 + %from_elements_1777 = tensor.from_elements %5127, %c4096_i64 : tensor<2xi64> + %5128 = stablehlo.real_dynamic_slice %5121, %c_22, %from_elements_1777, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_1778 = tensor.from_elements %5127, %c4096_i64, %c1_i64 : tensor<3xi64> + %5129 = stablehlo.dynamic_reshape %5126, %from_elements_1778 : (tensor, tensor<3xi64>) -> tensor + %5130 = stablehlo.dynamic_iota %from_elements_1778, dim = 1 : (tensor<3xi64>) -> tensor + %5131 = stablehlo.concatenate %5129, %5130, dim = 2 : (tensor, tensor) -> tensor + %5132 = "stablehlo.scatter"(%cst_2, %5131, %5128) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %5133 = stablehlo.slice %5067 [1:2, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %5134 = stablehlo.reshape %5133 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %5135 = stablehlo.custom_call @byteir.non_zero(%5134) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_1779 = tensor.dim %5135, %c0 : tensor + %5136 = arith.index_cast %dim_1779 : index to i64 + %from_elements_1780 = tensor.from_elements %5136, %c1_i64 : tensor<2xi64> + %5137 = stablehlo.real_dynamic_slice %5135, %c_22, %from_elements_1780, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1781 = tensor.dim %5137, %c0 : tensor + %5138 = arith.index_cast %dim_1781 : index to i64 + %from_elements_1782 = tensor.from_elements %5138 : tensor<1xi64> + %5139 = stablehlo.dynamic_reshape %5137, %from_elements_1782 : (tensor, tensor<1xi64>) -> tensor + %from_elements_1783 = tensor.from_elements %5136, %c2_i64 : tensor<2xi64> + %5140 = stablehlo.real_dynamic_slice %5135, %c_24, %from_elements_1783, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1784 = tensor.dim %5140, %c0 : tensor + %5141 = arith.index_cast %dim_1784 : index to i64 + %from_elements_1785 = tensor.from_elements %5141 : tensor<1xi64> + %5142 = stablehlo.dynamic_reshape %5140, %from_elements_1785 : (tensor, tensor<1xi64>) -> tensor + %dim_1786 = tensor.dim %5142, %c0 : tensor + %5143 = arith.index_cast %dim_1786 : index to i64 + %from_elements_1787 = tensor.from_elements %5143, %c1_i64 : tensor<2xi64> + %5144 = stablehlo.dynamic_reshape %5142, %from_elements_1787 : (tensor, tensor<2xi64>) -> tensor + %dim_1788 = tensor.dim %5144, %c0 : tensor + %5145 = arith.index_cast %dim_1788 : index to i64 + %from_elements_1789 = tensor.from_elements %c1_i64, %5145, %c4096_i64 : tensor<3xi64> + %5146 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_1789, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1790 = tensor.dim %5146, 
%c1 : tensor<1x?x4096xi64> + %5147 = arith.index_cast %dim_1790 : index to i64 + %from_elements_1791 = tensor.from_elements %c1_i64, %5147, %c4096_i64, %c1_i64 : tensor<4xi64> + %5148 = stablehlo.dynamic_reshape %5146, %from_elements_1791 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %5149 = stablehlo.dynamic_broadcast_in_dim %5144, %from_elements_1789, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1792 = tensor.dim %5149, %c1 : tensor<1x?x4096xi64> + %5150 = arith.index_cast %dim_1792 : index to i64 + %from_elements_1793 = tensor.from_elements %c1_i64, %5150, %c4096_i64, %c1_i64 : tensor<4xi64> + %5151 = stablehlo.dynamic_reshape %5149, %from_elements_1793 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %5152 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_1789, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1794 = tensor.dim %5152, %c1 : tensor<1x?x4096xi64> + %5153 = arith.index_cast %dim_1794 : index to i64 + %from_elements_1795 = tensor.from_elements %c1_i64, %5153, %c4096_i64, %c1_i64 : tensor<4xi64> + %5154 = stablehlo.dynamic_reshape %5152, %from_elements_1795 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %5155 = stablehlo.concatenate %5148, %5151, %5154, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %5156 = "stablehlo.gather"(%5078, %5155) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %5157 = shape.shape_of %5156 : tensor<1x?x4096xf32> -> tensor<3xindex> + %5158 = shape.num_elements %5157 : tensor<3xindex> -> index + %5159 = stablehlo.compute_reshape_shape %5158, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %5160 = stablehlo.dynamic_reshape %5156, %5159 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %5161 = stablehlo.dot %5160, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %5162 = stablehlo.logistic %5161 : tensor + %5163 = shape.shape_of %5162 : tensor -> tensor<2xindex> + %5164 = shape.shape_of %5161 : tensor -> tensor<2xindex> + %5165 = shape.cstr_broadcastable %5163, %5164 : tensor<2xindex>, tensor<2xindex> + %5166 = shape.assuming %5165 -> (tensor) { + %19688 = shape.broadcast %5163, %5164 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %5162, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %5161, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %5167 = shape.shape_of %5166 : tensor -> tensor<2xindex> + %5168 = shape.cstr_broadcastable %5167, %5164 : tensor<2xindex>, tensor<2xindex> + %5169 = shape.assuming %5168 -> (tensor) { + %19688 = shape.broadcast %5167, %5164 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %5166, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %5161, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %5170 = stablehlo.dot %5169, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_1796 = tensor.dim %5142, %c0 : tensor + %5171 = arith.index_cast %dim_1796 : index to i64 + %from_elements_1797 = 
tensor.from_elements %5171, %c1_i64 : tensor<2xi64> + %5172 = stablehlo.dynamic_reshape %5142, %from_elements_1797 : (tensor, tensor<2xi64>) -> tensor + %dim_1798 = tensor.dim %5139, %c0 : tensor + %5173 = arith.index_cast %dim_1798 : index to i64 + %from_elements_1799 = tensor.from_elements %5173, %c1_i64 : tensor<2xi64> + %5174 = stablehlo.dynamic_reshape %5139, %from_elements_1799 : (tensor, tensor<2xi64>) -> tensor + %5175 = stablehlo.concatenate %5172, %5174, dim = 1 : (tensor, tensor) -> tensor + %5176 = "stablehlo.gather"(%5107, %5175) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %5177 = shape.shape_of %5170 : tensor -> tensor<2xindex> + %5178 = shape.shape_of %5176 : tensor -> tensor<2xindex> + %5179 = shape.cstr_broadcastable %5177, %5178 : tensor<2xindex>, tensor<2xindex> + %5180 = shape.assuming %5179 -> (tensor) { + %19688 = shape.broadcast %5177, %5178 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %5170, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %5176, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %5181 = shape.shape_of %5180 : tensor -> tensor<2xindex> + %5182 = stablehlo.dynamic_broadcast_in_dim %5180, %5181, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %5183 = stablehlo.dynamic_broadcast_in_dim %213, %5181, dims = [] : (tensor, tensor<2xindex>) -> tensor + %5184 = stablehlo.multiply %5182, %5183 : tensor + %dim_1800 = tensor.dim %5144, %c0 : tensor + %5185 = arith.index_cast %dim_1800 : index to i64 + %dim_1801 = tensor.dim %5180, %c0 : tensor + %5186 = arith.index_cast %dim_1801 : index to i64 + %5187 = arith.maxsi %5185, %5186 : i64 + %5188 = arith.index_cast %5187 : i64 to index + %from_elements_1802 = tensor.from_elements %5188, %c4096 : tensor<2xindex> + %5189 = stablehlo.dynamic_broadcast_in_dim %5144, %from_elements_1802, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_1803 = tensor.dim %5189, %c0 : tensor + %5190 = arith.index_cast %dim_1803 : index to i64 + %from_elements_1804 = tensor.from_elements %5190, %c4096_i64 : tensor<2xi64> + %5191 = stablehlo.real_dynamic_slice %5184, %c_22, %from_elements_1804, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_1805 = tensor.from_elements %5190, %c4096_i64, %c1_i64 : tensor<3xi64> + %5192 = stablehlo.dynamic_reshape %5189, %from_elements_1805 : (tensor, tensor<3xi64>) -> tensor + %5193 = stablehlo.dynamic_iota %from_elements_1805, dim = 1 : (tensor<3xi64>) -> tensor + %5194 = stablehlo.concatenate %5192, %5193, dim = 2 : (tensor, tensor) -> tensor + %5195 = "stablehlo.scatter"(%5132, %5194, %5191) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %5196 = stablehlo.slice %5067 [2:3, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %5197 = stablehlo.reshape %5196 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %5198 = stablehlo.custom_call @byteir.non_zero(%5197) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_1806 = tensor.dim %5198, %c0 : tensor + %5199 = arith.index_cast %dim_1806 : index to 
i64 + %from_elements_1807 = tensor.from_elements %5199, %c1_i64 : tensor<2xi64> + %5200 = stablehlo.real_dynamic_slice %5198, %c_22, %from_elements_1807, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1808 = tensor.dim %5200, %c0 : tensor + %5201 = arith.index_cast %dim_1808 : index to i64 + %from_elements_1809 = tensor.from_elements %5201 : tensor<1xi64> + %5202 = stablehlo.dynamic_reshape %5200, %from_elements_1809 : (tensor, tensor<1xi64>) -> tensor + %from_elements_1810 = tensor.from_elements %5199, %c2_i64 : tensor<2xi64> + %5203 = stablehlo.real_dynamic_slice %5198, %c_24, %from_elements_1810, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1811 = tensor.dim %5203, %c0 : tensor + %5204 = arith.index_cast %dim_1811 : index to i64 + %from_elements_1812 = tensor.from_elements %5204 : tensor<1xi64> + %5205 = stablehlo.dynamic_reshape %5203, %from_elements_1812 : (tensor, tensor<1xi64>) -> tensor + %dim_1813 = tensor.dim %5205, %c0 : tensor + %5206 = arith.index_cast %dim_1813 : index to i64 + %from_elements_1814 = tensor.from_elements %5206, %c1_i64 : tensor<2xi64> + %5207 = stablehlo.dynamic_reshape %5205, %from_elements_1814 : (tensor, tensor<2xi64>) -> tensor + %dim_1815 = tensor.dim %5207, %c0 : tensor + %5208 = arith.index_cast %dim_1815 : index to i64 + %from_elements_1816 = tensor.from_elements %c1_i64, %5208, %c4096_i64 : tensor<3xi64> + %5209 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_1816, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1817 = tensor.dim %5209, %c1 : tensor<1x?x4096xi64> + %5210 = arith.index_cast %dim_1817 : index to i64 + %from_elements_1818 = tensor.from_elements %c1_i64, %5210, %c4096_i64, %c1_i64 : tensor<4xi64> + %5211 = stablehlo.dynamic_reshape %5209, %from_elements_1818 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %5212 = stablehlo.dynamic_broadcast_in_dim %5207, %from_elements_1816, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1819 = tensor.dim %5212, %c1 : tensor<1x?x4096xi64> + %5213 = arith.index_cast %dim_1819 : index to i64 + %from_elements_1820 = tensor.from_elements %c1_i64, %5213, %c4096_i64, %c1_i64 : tensor<4xi64> + %5214 = stablehlo.dynamic_reshape %5212, %from_elements_1820 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %5215 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_1816, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1821 = tensor.dim %5215, %c1 : tensor<1x?x4096xi64> + %5216 = arith.index_cast %dim_1821 : index to i64 + %from_elements_1822 = tensor.from_elements %c1_i64, %5216, %c4096_i64, %c1_i64 : tensor<4xi64> + %5217 = stablehlo.dynamic_reshape %5215, %from_elements_1822 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %5218 = stablehlo.concatenate %5211, %5214, %5217, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %5219 = "stablehlo.gather"(%5078, %5218) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %5220 = shape.shape_of %5219 : tensor<1x?x4096xf32> -> tensor<3xindex> + %5221 = shape.num_elements %5220 : tensor<3xindex> -> index + %5222 = stablehlo.compute_reshape_shape %5221, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %5223 = stablehlo.dynamic_reshape %5219, %5222 : (tensor<1x?x4096xf32>, 
tensor<2xi64>) -> tensor + %5224 = stablehlo.dot %5223, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %5225 = stablehlo.logistic %5224 : tensor + %5226 = shape.shape_of %5225 : tensor -> tensor<2xindex> + %5227 = shape.shape_of %5224 : tensor -> tensor<2xindex> + %5228 = shape.cstr_broadcastable %5226, %5227 : tensor<2xindex>, tensor<2xindex> + %5229 = shape.assuming %5228 -> (tensor) { + %19688 = shape.broadcast %5226, %5227 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %5225, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %5224, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %5230 = shape.shape_of %5229 : tensor -> tensor<2xindex> + %5231 = shape.cstr_broadcastable %5230, %5227 : tensor<2xindex>, tensor<2xindex> + %5232 = shape.assuming %5231 -> (tensor) { + %19688 = shape.broadcast %5230, %5227 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %5229, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %5224, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %5233 = stablehlo.dot %5232, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_1823 = tensor.dim %5205, %c0 : tensor + %5234 = arith.index_cast %dim_1823 : index to i64 + %from_elements_1824 = tensor.from_elements %5234, %c1_i64 : tensor<2xi64> + %5235 = stablehlo.dynamic_reshape %5205, %from_elements_1824 : (tensor, tensor<2xi64>) -> tensor + %dim_1825 = tensor.dim %5202, %c0 : tensor + %5236 = arith.index_cast %dim_1825 : index to i64 + %from_elements_1826 = tensor.from_elements %5236, %c1_i64 : tensor<2xi64> + %5237 = stablehlo.dynamic_reshape %5202, %from_elements_1826 : (tensor, tensor<2xi64>) -> tensor + %5238 = stablehlo.concatenate %5235, %5237, dim = 1 : (tensor, tensor) -> tensor + %5239 = "stablehlo.gather"(%5107, %5238) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %5240 = shape.shape_of %5233 : tensor -> tensor<2xindex> + %5241 = shape.shape_of %5239 : tensor -> tensor<2xindex> + %5242 = shape.cstr_broadcastable %5240, %5241 : tensor<2xindex>, tensor<2xindex> + %5243 = shape.assuming %5242 -> (tensor) { + %19688 = shape.broadcast %5240, %5241 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %5233, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %5239, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %5244 = shape.shape_of %5243 : tensor -> tensor<2xindex> + %5245 = stablehlo.dynamic_broadcast_in_dim %5243, %5244, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %5246 = stablehlo.dynamic_broadcast_in_dim %213, %5244, dims = [] : (tensor, tensor<2xindex>) -> tensor + %5247 = stablehlo.multiply %5245, %5246 : tensor + %dim_1827 = tensor.dim %5207, %c0 : tensor + %5248 = arith.index_cast %dim_1827 : index to i64 + %dim_1828 = tensor.dim %5243, %c0 : tensor + %5249 = arith.index_cast %dim_1828 : index to i64 + %5250 = arith.maxsi %5248, %5249 : i64 + %5251 = arith.index_cast 
%5250 : i64 to index + %from_elements_1829 = tensor.from_elements %5251, %c4096 : tensor<2xindex> + %5252 = stablehlo.dynamic_broadcast_in_dim %5207, %from_elements_1829, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_1830 = tensor.dim %5252, %c0 : tensor + %5253 = arith.index_cast %dim_1830 : index to i64 + %from_elements_1831 = tensor.from_elements %5253, %c4096_i64 : tensor<2xi64> + %5254 = stablehlo.real_dynamic_slice %5247, %c_22, %from_elements_1831, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_1832 = tensor.from_elements %5253, %c4096_i64, %c1_i64 : tensor<3xi64> + %5255 = stablehlo.dynamic_reshape %5252, %from_elements_1832 : (tensor, tensor<3xi64>) -> tensor + %5256 = stablehlo.dynamic_iota %from_elements_1832, dim = 1 : (tensor<3xi64>) -> tensor + %5257 = stablehlo.concatenate %5255, %5256, dim = 2 : (tensor, tensor) -> tensor + %5258 = "stablehlo.scatter"(%5195, %5257, %5254) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %5259 = stablehlo.slice %5067 [3:4, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %5260 = stablehlo.reshape %5259 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %5261 = stablehlo.custom_call @byteir.non_zero(%5260) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_1833 = tensor.dim %5261, %c0 : tensor + %5262 = arith.index_cast %dim_1833 : index to i64 + %from_elements_1834 = tensor.from_elements %5262, %c1_i64 : tensor<2xi64> + %5263 = stablehlo.real_dynamic_slice %5261, %c_22, %from_elements_1834, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1835 = tensor.dim %5263, %c0 : tensor + %5264 = arith.index_cast %dim_1835 : index to i64 + %from_elements_1836 = tensor.from_elements %5264 : tensor<1xi64> + %5265 = stablehlo.dynamic_reshape %5263, %from_elements_1836 : (tensor, tensor<1xi64>) -> tensor + %from_elements_1837 = tensor.from_elements %5262, %c2_i64 : tensor<2xi64> + %5266 = stablehlo.real_dynamic_slice %5261, %c_24, %from_elements_1837, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1838 = tensor.dim %5266, %c0 : tensor + %5267 = arith.index_cast %dim_1838 : index to i64 + %from_elements_1839 = tensor.from_elements %5267 : tensor<1xi64> + %5268 = stablehlo.dynamic_reshape %5266, %from_elements_1839 : (tensor, tensor<1xi64>) -> tensor + %dim_1840 = tensor.dim %5268, %c0 : tensor + %5269 = arith.index_cast %dim_1840 : index to i64 + %from_elements_1841 = tensor.from_elements %5269, %c1_i64 : tensor<2xi64> + %5270 = stablehlo.dynamic_reshape %5268, %from_elements_1841 : (tensor, tensor<2xi64>) -> tensor + %dim_1842 = tensor.dim %5270, %c0 : tensor + %5271 = arith.index_cast %dim_1842 : index to i64 + %from_elements_1843 = tensor.from_elements %c1_i64, %5271, %c4096_i64 : tensor<3xi64> + %5272 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_1843, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1844 = tensor.dim %5272, %c1 : tensor<1x?x4096xi64> + %5273 = arith.index_cast %dim_1844 : index to i64 + %from_elements_1845 = tensor.from_elements %c1_i64, %5273, %c4096_i64, %c1_i64 : tensor<4xi64> + %5274 = stablehlo.dynamic_reshape %5272, %from_elements_1845 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %5275 = 
stablehlo.dynamic_broadcast_in_dim %5270, %from_elements_1843, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1846 = tensor.dim %5275, %c1 : tensor<1x?x4096xi64> + %5276 = arith.index_cast %dim_1846 : index to i64 + %from_elements_1847 = tensor.from_elements %c1_i64, %5276, %c4096_i64, %c1_i64 : tensor<4xi64> + %5277 = stablehlo.dynamic_reshape %5275, %from_elements_1847 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %5278 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_1843, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1848 = tensor.dim %5278, %c1 : tensor<1x?x4096xi64> + %5279 = arith.index_cast %dim_1848 : index to i64 + %from_elements_1849 = tensor.from_elements %c1_i64, %5279, %c4096_i64, %c1_i64 : tensor<4xi64> + %5280 = stablehlo.dynamic_reshape %5278, %from_elements_1849 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %5281 = stablehlo.concatenate %5274, %5277, %5280, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %5282 = "stablehlo.gather"(%5078, %5281) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %5283 = shape.shape_of %5282 : tensor<1x?x4096xf32> -> tensor<3xindex> + %5284 = shape.num_elements %5283 : tensor<3xindex> -> index + %5285 = stablehlo.compute_reshape_shape %5284, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %5286 = stablehlo.dynamic_reshape %5282, %5285 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %5287 = stablehlo.dot %5286, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %5288 = stablehlo.logistic %5287 : tensor + %5289 = shape.shape_of %5288 : tensor -> tensor<2xindex> + %5290 = shape.shape_of %5287 : tensor -> tensor<2xindex> + %5291 = shape.cstr_broadcastable %5289, %5290 : tensor<2xindex>, tensor<2xindex> + %5292 = shape.assuming %5291 -> (tensor) { + %19688 = shape.broadcast %5289, %5290 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %5288, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %5287, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %5293 = shape.shape_of %5292 : tensor -> tensor<2xindex> + %5294 = shape.cstr_broadcastable %5293, %5290 : tensor<2xindex>, tensor<2xindex> + %5295 = shape.assuming %5294 -> (tensor) { + %19688 = shape.broadcast %5293, %5290 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %5292, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %5287, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %5296 = stablehlo.dot %5295, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_1850 = tensor.dim %5268, %c0 : tensor + %5297 = arith.index_cast %dim_1850 : index to i64 + %from_elements_1851 = tensor.from_elements %5297, %c1_i64 : tensor<2xi64> + %5298 = stablehlo.dynamic_reshape %5268, %from_elements_1851 : (tensor, tensor<2xi64>) -> tensor + %dim_1852 = tensor.dim %5265, %c0 : tensor + %5299 = arith.index_cast %dim_1852 : index to i64 + %from_elements_1853 = tensor.from_elements %5299, %c1_i64 : 
tensor<2xi64> + %5300 = stablehlo.dynamic_reshape %5265, %from_elements_1853 : (tensor, tensor<2xi64>) -> tensor + %5301 = stablehlo.concatenate %5298, %5300, dim = 1 : (tensor, tensor) -> tensor + %5302 = "stablehlo.gather"(%5107, %5301) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %5303 = shape.shape_of %5296 : tensor -> tensor<2xindex> + %5304 = shape.shape_of %5302 : tensor -> tensor<2xindex> + %5305 = shape.cstr_broadcastable %5303, %5304 : tensor<2xindex>, tensor<2xindex> + %5306 = shape.assuming %5305 -> (tensor) { + %19688 = shape.broadcast %5303, %5304 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %5296, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %5302, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %5307 = shape.shape_of %5306 : tensor -> tensor<2xindex> + %5308 = stablehlo.dynamic_broadcast_in_dim %5306, %5307, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %5309 = stablehlo.dynamic_broadcast_in_dim %213, %5307, dims = [] : (tensor, tensor<2xindex>) -> tensor + %5310 = stablehlo.multiply %5308, %5309 : tensor + %dim_1854 = tensor.dim %5270, %c0 : tensor + %5311 = arith.index_cast %dim_1854 : index to i64 + %dim_1855 = tensor.dim %5306, %c0 : tensor + %5312 = arith.index_cast %dim_1855 : index to i64 + %5313 = arith.maxsi %5311, %5312 : i64 + %5314 = arith.index_cast %5313 : i64 to index + %from_elements_1856 = tensor.from_elements %5314, %c4096 : tensor<2xindex> + %5315 = stablehlo.dynamic_broadcast_in_dim %5270, %from_elements_1856, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_1857 = tensor.dim %5315, %c0 : tensor + %5316 = arith.index_cast %dim_1857 : index to i64 + %from_elements_1858 = tensor.from_elements %5316, %c4096_i64 : tensor<2xi64> + %5317 = stablehlo.real_dynamic_slice %5310, %c_22, %from_elements_1858, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_1859 = tensor.from_elements %5316, %c4096_i64, %c1_i64 : tensor<3xi64> + %5318 = stablehlo.dynamic_reshape %5315, %from_elements_1859 : (tensor, tensor<3xi64>) -> tensor + %5319 = stablehlo.dynamic_iota %from_elements_1859, dim = 1 : (tensor<3xi64>) -> tensor + %5320 = stablehlo.concatenate %5318, %5319, dim = 2 : (tensor, tensor) -> tensor + %5321 = "stablehlo.scatter"(%5258, %5320, %5317) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %5322 = stablehlo.slice %5067 [4:5, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %5323 = stablehlo.reshape %5322 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %5324 = stablehlo.custom_call @byteir.non_zero(%5323) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_1860 = tensor.dim %5324, %c0 : tensor + %5325 = arith.index_cast %dim_1860 : index to i64 + %from_elements_1861 = tensor.from_elements %5325, %c1_i64 : tensor<2xi64> + %5326 = stablehlo.real_dynamic_slice %5324, %c_22, %from_elements_1861, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1862 = tensor.dim %5326, %c0 : tensor + %5327 = arith.index_cast %dim_1862 : 
index to i64 + %from_elements_1863 = tensor.from_elements %5327 : tensor<1xi64> + %5328 = stablehlo.dynamic_reshape %5326, %from_elements_1863 : (tensor, tensor<1xi64>) -> tensor + %from_elements_1864 = tensor.from_elements %5325, %c2_i64 : tensor<2xi64> + %5329 = stablehlo.real_dynamic_slice %5324, %c_24, %from_elements_1864, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1865 = tensor.dim %5329, %c0 : tensor + %5330 = arith.index_cast %dim_1865 : index to i64 + %from_elements_1866 = tensor.from_elements %5330 : tensor<1xi64> + %5331 = stablehlo.dynamic_reshape %5329, %from_elements_1866 : (tensor, tensor<1xi64>) -> tensor + %dim_1867 = tensor.dim %5331, %c0 : tensor + %5332 = arith.index_cast %dim_1867 : index to i64 + %from_elements_1868 = tensor.from_elements %5332, %c1_i64 : tensor<2xi64> + %5333 = stablehlo.dynamic_reshape %5331, %from_elements_1868 : (tensor, tensor<2xi64>) -> tensor + %dim_1869 = tensor.dim %5333, %c0 : tensor + %5334 = arith.index_cast %dim_1869 : index to i64 + %from_elements_1870 = tensor.from_elements %c1_i64, %5334, %c4096_i64 : tensor<3xi64> + %5335 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_1870, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1871 = tensor.dim %5335, %c1 : tensor<1x?x4096xi64> + %5336 = arith.index_cast %dim_1871 : index to i64 + %from_elements_1872 = tensor.from_elements %c1_i64, %5336, %c4096_i64, %c1_i64 : tensor<4xi64> + %5337 = stablehlo.dynamic_reshape %5335, %from_elements_1872 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %5338 = stablehlo.dynamic_broadcast_in_dim %5333, %from_elements_1870, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1873 = tensor.dim %5338, %c1 : tensor<1x?x4096xi64> + %5339 = arith.index_cast %dim_1873 : index to i64 + %from_elements_1874 = tensor.from_elements %c1_i64, %5339, %c4096_i64, %c1_i64 : tensor<4xi64> + %5340 = stablehlo.dynamic_reshape %5338, %from_elements_1874 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %5341 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_1870, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1875 = tensor.dim %5341, %c1 : tensor<1x?x4096xi64> + %5342 = arith.index_cast %dim_1875 : index to i64 + %from_elements_1876 = tensor.from_elements %c1_i64, %5342, %c4096_i64, %c1_i64 : tensor<4xi64> + %5343 = stablehlo.dynamic_reshape %5341, %from_elements_1876 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %5344 = stablehlo.concatenate %5337, %5340, %5343, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %5345 = "stablehlo.gather"(%5078, %5344) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %5346 = shape.shape_of %5345 : tensor<1x?x4096xf32> -> tensor<3xindex> + %5347 = shape.num_elements %5346 : tensor<3xindex> -> index + %5348 = stablehlo.compute_reshape_shape %5347, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %5349 = stablehlo.dynamic_reshape %5345, %5348 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %5350 = stablehlo.dot %5349, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %5351 = stablehlo.logistic %5350 : tensor + %5352 = shape.shape_of %5351 : tensor -> tensor<2xindex> + %5353 = shape.shape_of %5350 : tensor -> tensor<2xindex> + %5354 = shape.cstr_broadcastable %5352, %5353 : 
tensor<2xindex>, tensor<2xindex> + %5355 = shape.assuming %5354 -> (tensor) { + %19688 = shape.broadcast %5352, %5353 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %5351, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %5350, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %5356 = shape.shape_of %5355 : tensor -> tensor<2xindex> + %5357 = shape.cstr_broadcastable %5356, %5353 : tensor<2xindex>, tensor<2xindex> + %5358 = shape.assuming %5357 -> (tensor) { + %19688 = shape.broadcast %5356, %5353 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %5355, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %5350, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %5359 = stablehlo.dot %5358, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_1877 = tensor.dim %5331, %c0 : tensor + %5360 = arith.index_cast %dim_1877 : index to i64 + %from_elements_1878 = tensor.from_elements %5360, %c1_i64 : tensor<2xi64> + %5361 = stablehlo.dynamic_reshape %5331, %from_elements_1878 : (tensor, tensor<2xi64>) -> tensor + %dim_1879 = tensor.dim %5328, %c0 : tensor + %5362 = arith.index_cast %dim_1879 : index to i64 + %from_elements_1880 = tensor.from_elements %5362, %c1_i64 : tensor<2xi64> + %5363 = stablehlo.dynamic_reshape %5328, %from_elements_1880 : (tensor, tensor<2xi64>) -> tensor + %5364 = stablehlo.concatenate %5361, %5363, dim = 1 : (tensor, tensor) -> tensor + %5365 = "stablehlo.gather"(%5107, %5364) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %5366 = shape.shape_of %5359 : tensor -> tensor<2xindex> + %5367 = shape.shape_of %5365 : tensor -> tensor<2xindex> + %5368 = shape.cstr_broadcastable %5366, %5367 : tensor<2xindex>, tensor<2xindex> + %5369 = shape.assuming %5368 -> (tensor) { + %19688 = shape.broadcast %5366, %5367 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %5359, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %5365, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %5370 = shape.shape_of %5369 : tensor -> tensor<2xindex> + %5371 = stablehlo.dynamic_broadcast_in_dim %5369, %5370, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %5372 = stablehlo.dynamic_broadcast_in_dim %213, %5370, dims = [] : (tensor, tensor<2xindex>) -> tensor + %5373 = stablehlo.multiply %5371, %5372 : tensor + %dim_1881 = tensor.dim %5333, %c0 : tensor + %5374 = arith.index_cast %dim_1881 : index to i64 + %dim_1882 = tensor.dim %5369, %c0 : tensor + %5375 = arith.index_cast %dim_1882 : index to i64 + %5376 = arith.maxsi %5374, %5375 : i64 + %5377 = arith.index_cast %5376 : i64 to index + %from_elements_1883 = tensor.from_elements %5377, %c4096 : tensor<2xindex> + %5378 = stablehlo.dynamic_broadcast_in_dim %5333, %from_elements_1883, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_1884 = tensor.dim %5378, %c0 : tensor + %5379 = arith.index_cast %dim_1884 : index to i64 + 
%from_elements_1885 = tensor.from_elements %5379, %c4096_i64 : tensor<2xi64> + %5380 = stablehlo.real_dynamic_slice %5373, %c_22, %from_elements_1885, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_1886 = tensor.from_elements %5379, %c4096_i64, %c1_i64 : tensor<3xi64> + %5381 = stablehlo.dynamic_reshape %5378, %from_elements_1886 : (tensor, tensor<3xi64>) -> tensor + %5382 = stablehlo.dynamic_iota %from_elements_1886, dim = 1 : (tensor<3xi64>) -> tensor + %5383 = stablehlo.concatenate %5381, %5382, dim = 2 : (tensor, tensor) -> tensor + %5384 = "stablehlo.scatter"(%5321, %5383, %5380) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %5385 = stablehlo.slice %5067 [5:6, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %5386 = stablehlo.reshape %5385 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %5387 = stablehlo.custom_call @byteir.non_zero(%5386) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_1887 = tensor.dim %5387, %c0 : tensor + %5388 = arith.index_cast %dim_1887 : index to i64 + %from_elements_1888 = tensor.from_elements %5388, %c1_i64 : tensor<2xi64> + %5389 = stablehlo.real_dynamic_slice %5387, %c_22, %from_elements_1888, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1889 = tensor.dim %5389, %c0 : tensor + %5390 = arith.index_cast %dim_1889 : index to i64 + %from_elements_1890 = tensor.from_elements %5390 : tensor<1xi64> + %5391 = stablehlo.dynamic_reshape %5389, %from_elements_1890 : (tensor, tensor<1xi64>) -> tensor + %from_elements_1891 = tensor.from_elements %5388, %c2_i64 : tensor<2xi64> + %5392 = stablehlo.real_dynamic_slice %5387, %c_24, %from_elements_1891, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1892 = tensor.dim %5392, %c0 : tensor + %5393 = arith.index_cast %dim_1892 : index to i64 + %from_elements_1893 = tensor.from_elements %5393 : tensor<1xi64> + %5394 = stablehlo.dynamic_reshape %5392, %from_elements_1893 : (tensor, tensor<1xi64>) -> tensor + %dim_1894 = tensor.dim %5394, %c0 : tensor + %5395 = arith.index_cast %dim_1894 : index to i64 + %from_elements_1895 = tensor.from_elements %5395, %c1_i64 : tensor<2xi64> + %5396 = stablehlo.dynamic_reshape %5394, %from_elements_1895 : (tensor, tensor<2xi64>) -> tensor + %dim_1896 = tensor.dim %5396, %c0 : tensor + %5397 = arith.index_cast %dim_1896 : index to i64 + %from_elements_1897 = tensor.from_elements %c1_i64, %5397, %c4096_i64 : tensor<3xi64> + %5398 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_1897, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1898 = tensor.dim %5398, %c1 : tensor<1x?x4096xi64> + %5399 = arith.index_cast %dim_1898 : index to i64 + %from_elements_1899 = tensor.from_elements %c1_i64, %5399, %c4096_i64, %c1_i64 : tensor<4xi64> + %5400 = stablehlo.dynamic_reshape %5398, %from_elements_1899 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %5401 = stablehlo.dynamic_broadcast_in_dim %5396, %from_elements_1897, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1900 = tensor.dim %5401, %c1 : tensor<1x?x4096xi64> + %5402 = arith.index_cast %dim_1900 : index to i64 + %from_elements_1901 = tensor.from_elements %c1_i64, %5402, %c4096_i64, %c1_i64 : 
tensor<4xi64> + %5403 = stablehlo.dynamic_reshape %5401, %from_elements_1901 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %5404 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_1897, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1902 = tensor.dim %5404, %c1 : tensor<1x?x4096xi64> + %5405 = arith.index_cast %dim_1902 : index to i64 + %from_elements_1903 = tensor.from_elements %c1_i64, %5405, %c4096_i64, %c1_i64 : tensor<4xi64> + %5406 = stablehlo.dynamic_reshape %5404, %from_elements_1903 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %5407 = stablehlo.concatenate %5400, %5403, %5406, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %5408 = "stablehlo.gather"(%5078, %5407) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %5409 = shape.shape_of %5408 : tensor<1x?x4096xf32> -> tensor<3xindex> + %5410 = shape.num_elements %5409 : tensor<3xindex> -> index + %5411 = stablehlo.compute_reshape_shape %5410, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %5412 = stablehlo.dynamic_reshape %5408, %5411 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %5413 = stablehlo.dot %5412, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %5414 = stablehlo.logistic %5413 : tensor + %5415 = shape.shape_of %5414 : tensor -> tensor<2xindex> + %5416 = shape.shape_of %5413 : tensor -> tensor<2xindex> + %5417 = shape.cstr_broadcastable %5415, %5416 : tensor<2xindex>, tensor<2xindex> + %5418 = shape.assuming %5417 -> (tensor) { + %19688 = shape.broadcast %5415, %5416 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %5414, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %5413, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %5419 = shape.shape_of %5418 : tensor -> tensor<2xindex> + %5420 = shape.cstr_broadcastable %5419, %5416 : tensor<2xindex>, tensor<2xindex> + %5421 = shape.assuming %5420 -> (tensor) { + %19688 = shape.broadcast %5419, %5416 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %5418, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %5413, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %5422 = stablehlo.dot %5421, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_1904 = tensor.dim %5394, %c0 : tensor + %5423 = arith.index_cast %dim_1904 : index to i64 + %from_elements_1905 = tensor.from_elements %5423, %c1_i64 : tensor<2xi64> + %5424 = stablehlo.dynamic_reshape %5394, %from_elements_1905 : (tensor, tensor<2xi64>) -> tensor + %dim_1906 = tensor.dim %5391, %c0 : tensor + %5425 = arith.index_cast %dim_1906 : index to i64 + %from_elements_1907 = tensor.from_elements %5425, %c1_i64 : tensor<2xi64> + %5426 = stablehlo.dynamic_reshape %5391, %from_elements_1907 : (tensor, tensor<2xi64>) -> tensor + %5427 = stablehlo.concatenate %5424, %5426, dim = 1 : (tensor, tensor) -> tensor + %5428 = "stablehlo.gather"(%5107, %5427) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = 
array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %5429 = shape.shape_of %5422 : tensor -> tensor<2xindex> + %5430 = shape.shape_of %5428 : tensor -> tensor<2xindex> + %5431 = shape.cstr_broadcastable %5429, %5430 : tensor<2xindex>, tensor<2xindex> + %5432 = shape.assuming %5431 -> (tensor) { + %19688 = shape.broadcast %5429, %5430 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %5422, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %5428, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %5433 = shape.shape_of %5432 : tensor -> tensor<2xindex> + %5434 = stablehlo.dynamic_broadcast_in_dim %5432, %5433, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %5435 = stablehlo.dynamic_broadcast_in_dim %213, %5433, dims = [] : (tensor, tensor<2xindex>) -> tensor + %5436 = stablehlo.multiply %5434, %5435 : tensor + %dim_1908 = tensor.dim %5396, %c0 : tensor + %5437 = arith.index_cast %dim_1908 : index to i64 + %dim_1909 = tensor.dim %5432, %c0 : tensor + %5438 = arith.index_cast %dim_1909 : index to i64 + %5439 = arith.maxsi %5437, %5438 : i64 + %5440 = arith.index_cast %5439 : i64 to index + %from_elements_1910 = tensor.from_elements %5440, %c4096 : tensor<2xindex> + %5441 = stablehlo.dynamic_broadcast_in_dim %5396, %from_elements_1910, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_1911 = tensor.dim %5441, %c0 : tensor + %5442 = arith.index_cast %dim_1911 : index to i64 + %from_elements_1912 = tensor.from_elements %5442, %c4096_i64 : tensor<2xi64> + %5443 = stablehlo.real_dynamic_slice %5436, %c_22, %from_elements_1912, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_1913 = tensor.from_elements %5442, %c4096_i64, %c1_i64 : tensor<3xi64> + %5444 = stablehlo.dynamic_reshape %5441, %from_elements_1913 : (tensor, tensor<3xi64>) -> tensor + %5445 = stablehlo.dynamic_iota %from_elements_1913, dim = 1 : (tensor<3xi64>) -> tensor + %5446 = stablehlo.concatenate %5444, %5445, dim = 2 : (tensor, tensor) -> tensor + %5447 = "stablehlo.scatter"(%5384, %5446, %5443) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %5448 = stablehlo.slice %5067 [6:7, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %5449 = stablehlo.reshape %5448 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %5450 = stablehlo.custom_call @byteir.non_zero(%5449) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_1914 = tensor.dim %5450, %c0 : tensor + %5451 = arith.index_cast %dim_1914 : index to i64 + %from_elements_1915 = tensor.from_elements %5451, %c1_i64 : tensor<2xi64> + %5452 = stablehlo.real_dynamic_slice %5450, %c_22, %from_elements_1915, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1916 = tensor.dim %5452, %c0 : tensor + %5453 = arith.index_cast %dim_1916 : index to i64 + %from_elements_1917 = tensor.from_elements %5453 : tensor<1xi64> + %5454 = stablehlo.dynamic_reshape %5452, %from_elements_1917 : (tensor, tensor<1xi64>) -> tensor + %from_elements_1918 = tensor.from_elements %5451, %c2_i64 : tensor<2xi64> + %5455 = stablehlo.real_dynamic_slice %5450, %c_24, 
%from_elements_1918, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1919 = tensor.dim %5455, %c0 : tensor + %5456 = arith.index_cast %dim_1919 : index to i64 + %from_elements_1920 = tensor.from_elements %5456 : tensor<1xi64> + %5457 = stablehlo.dynamic_reshape %5455, %from_elements_1920 : (tensor, tensor<1xi64>) -> tensor + %dim_1921 = tensor.dim %5457, %c0 : tensor + %5458 = arith.index_cast %dim_1921 : index to i64 + %from_elements_1922 = tensor.from_elements %5458, %c1_i64 : tensor<2xi64> + %5459 = stablehlo.dynamic_reshape %5457, %from_elements_1922 : (tensor, tensor<2xi64>) -> tensor + %dim_1923 = tensor.dim %5459, %c0 : tensor + %5460 = arith.index_cast %dim_1923 : index to i64 + %from_elements_1924 = tensor.from_elements %c1_i64, %5460, %c4096_i64 : tensor<3xi64> + %5461 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_1924, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1925 = tensor.dim %5461, %c1 : tensor<1x?x4096xi64> + %5462 = arith.index_cast %dim_1925 : index to i64 + %from_elements_1926 = tensor.from_elements %c1_i64, %5462, %c4096_i64, %c1_i64 : tensor<4xi64> + %5463 = stablehlo.dynamic_reshape %5461, %from_elements_1926 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %5464 = stablehlo.dynamic_broadcast_in_dim %5459, %from_elements_1924, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1927 = tensor.dim %5464, %c1 : tensor<1x?x4096xi64> + %5465 = arith.index_cast %dim_1927 : index to i64 + %from_elements_1928 = tensor.from_elements %c1_i64, %5465, %c4096_i64, %c1_i64 : tensor<4xi64> + %5466 = stablehlo.dynamic_reshape %5464, %from_elements_1928 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %5467 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_1924, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1929 = tensor.dim %5467, %c1 : tensor<1x?x4096xi64> + %5468 = arith.index_cast %dim_1929 : index to i64 + %from_elements_1930 = tensor.from_elements %c1_i64, %5468, %c4096_i64, %c1_i64 : tensor<4xi64> + %5469 = stablehlo.dynamic_reshape %5467, %from_elements_1930 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %5470 = stablehlo.concatenate %5463, %5466, %5469, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %5471 = "stablehlo.gather"(%5078, %5470) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %5472 = shape.shape_of %5471 : tensor<1x?x4096xf32> -> tensor<3xindex> + %5473 = shape.num_elements %5472 : tensor<3xindex> -> index + %5474 = stablehlo.compute_reshape_shape %5473, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %5475 = stablehlo.dynamic_reshape %5471, %5474 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %5476 = stablehlo.dot %5475, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %5477 = stablehlo.logistic %5476 : tensor + %5478 = shape.shape_of %5477 : tensor -> tensor<2xindex> + %5479 = shape.shape_of %5476 : tensor -> tensor<2xindex> + %5480 = shape.cstr_broadcastable %5478, %5479 : tensor<2xindex>, tensor<2xindex> + %5481 = shape.assuming %5480 -> (tensor) { + %19688 = shape.broadcast %5478, %5479 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %5477, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = 
stablehlo.dynamic_broadcast_in_dim %5476, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %5482 = shape.shape_of %5481 : tensor -> tensor<2xindex> + %5483 = shape.cstr_broadcastable %5482, %5479 : tensor<2xindex>, tensor<2xindex> + %5484 = shape.assuming %5483 -> (tensor) { + %19688 = shape.broadcast %5482, %5479 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %5481, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %5476, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %5485 = stablehlo.dot %5484, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_1931 = tensor.dim %5457, %c0 : tensor + %5486 = arith.index_cast %dim_1931 : index to i64 + %from_elements_1932 = tensor.from_elements %5486, %c1_i64 : tensor<2xi64> + %5487 = stablehlo.dynamic_reshape %5457, %from_elements_1932 : (tensor, tensor<2xi64>) -> tensor + %dim_1933 = tensor.dim %5454, %c0 : tensor + %5488 = arith.index_cast %dim_1933 : index to i64 + %from_elements_1934 = tensor.from_elements %5488, %c1_i64 : tensor<2xi64> + %5489 = stablehlo.dynamic_reshape %5454, %from_elements_1934 : (tensor, tensor<2xi64>) -> tensor + %5490 = stablehlo.concatenate %5487, %5489, dim = 1 : (tensor, tensor) -> tensor + %5491 = "stablehlo.gather"(%5107, %5490) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %5492 = shape.shape_of %5485 : tensor -> tensor<2xindex> + %5493 = shape.shape_of %5491 : tensor -> tensor<2xindex> + %5494 = shape.cstr_broadcastable %5492, %5493 : tensor<2xindex>, tensor<2xindex> + %5495 = shape.assuming %5494 -> (tensor) { + %19688 = shape.broadcast %5492, %5493 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %5485, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %5491, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %5496 = shape.shape_of %5495 : tensor -> tensor<2xindex> + %5497 = stablehlo.dynamic_broadcast_in_dim %5495, %5496, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %5498 = stablehlo.dynamic_broadcast_in_dim %213, %5496, dims = [] : (tensor, tensor<2xindex>) -> tensor + %5499 = stablehlo.multiply %5497, %5498 : tensor + %dim_1935 = tensor.dim %5459, %c0 : tensor + %5500 = arith.index_cast %dim_1935 : index to i64 + %dim_1936 = tensor.dim %5495, %c0 : tensor + %5501 = arith.index_cast %dim_1936 : index to i64 + %5502 = arith.maxsi %5500, %5501 : i64 + %5503 = arith.index_cast %5502 : i64 to index + %from_elements_1937 = tensor.from_elements %5503, %c4096 : tensor<2xindex> + %5504 = stablehlo.dynamic_broadcast_in_dim %5459, %from_elements_1937, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_1938 = tensor.dim %5504, %c0 : tensor + %5505 = arith.index_cast %dim_1938 : index to i64 + %from_elements_1939 = tensor.from_elements %5505, %c4096_i64 : tensor<2xi64> + %5506 = stablehlo.real_dynamic_slice %5499, %c_22, %from_elements_1939, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_1940 = tensor.from_elements %5505, %c4096_i64, %c1_i64 
: tensor<3xi64> + %5507 = stablehlo.dynamic_reshape %5504, %from_elements_1940 : (tensor, tensor<3xi64>) -> tensor + %5508 = stablehlo.dynamic_iota %from_elements_1940, dim = 1 : (tensor<3xi64>) -> tensor + %5509 = stablehlo.concatenate %5507, %5508, dim = 2 : (tensor, tensor) -> tensor + %5510 = "stablehlo.scatter"(%5447, %5509, %5506) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %5511 = stablehlo.slice %5067 [7:8, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %5512 = stablehlo.reshape %5511 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %5513 = stablehlo.custom_call @byteir.non_zero(%5512) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_1941 = tensor.dim %5513, %c0 : tensor + %5514 = arith.index_cast %dim_1941 : index to i64 + %from_elements_1942 = tensor.from_elements %5514, %c1_i64 : tensor<2xi64> + %5515 = stablehlo.real_dynamic_slice %5513, %c_22, %from_elements_1942, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1943 = tensor.dim %5515, %c0 : tensor + %5516 = arith.index_cast %dim_1943 : index to i64 + %from_elements_1944 = tensor.from_elements %5516 : tensor<1xi64> + %5517 = stablehlo.dynamic_reshape %5515, %from_elements_1944 : (tensor, tensor<1xi64>) -> tensor + %from_elements_1945 = tensor.from_elements %5514, %c2_i64 : tensor<2xi64> + %5518 = stablehlo.real_dynamic_slice %5513, %c_24, %from_elements_1945, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1946 = tensor.dim %5518, %c0 : tensor + %5519 = arith.index_cast %dim_1946 : index to i64 + %from_elements_1947 = tensor.from_elements %5519 : tensor<1xi64> + %5520 = stablehlo.dynamic_reshape %5518, %from_elements_1947 : (tensor, tensor<1xi64>) -> tensor + %dim_1948 = tensor.dim %5520, %c0 : tensor + %5521 = arith.index_cast %dim_1948 : index to i64 + %from_elements_1949 = tensor.from_elements %5521, %c1_i64 : tensor<2xi64> + %5522 = stablehlo.dynamic_reshape %5520, %from_elements_1949 : (tensor, tensor<2xi64>) -> tensor + %dim_1950 = tensor.dim %5522, %c0 : tensor + %5523 = arith.index_cast %dim_1950 : index to i64 + %from_elements_1951 = tensor.from_elements %c1_i64, %5523, %c4096_i64 : tensor<3xi64> + %5524 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_1951, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1952 = tensor.dim %5524, %c1 : tensor<1x?x4096xi64> + %5525 = arith.index_cast %dim_1952 : index to i64 + %from_elements_1953 = tensor.from_elements %c1_i64, %5525, %c4096_i64, %c1_i64 : tensor<4xi64> + %5526 = stablehlo.dynamic_reshape %5524, %from_elements_1953 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %5527 = stablehlo.dynamic_broadcast_in_dim %5522, %from_elements_1951, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1954 = tensor.dim %5527, %c1 : tensor<1x?x4096xi64> + %5528 = arith.index_cast %dim_1954 : index to i64 + %from_elements_1955 = tensor.from_elements %c1_i64, %5528, %c4096_i64, %c1_i64 : tensor<4xi64> + %5529 = stablehlo.dynamic_reshape %5527, %from_elements_1955 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %5530 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_1951, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1956 = 
tensor.dim %5530, %c1 : tensor<1x?x4096xi64> + %5531 = arith.index_cast %dim_1956 : index to i64 + %from_elements_1957 = tensor.from_elements %c1_i64, %5531, %c4096_i64, %c1_i64 : tensor<4xi64> + %5532 = stablehlo.dynamic_reshape %5530, %from_elements_1957 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %5533 = stablehlo.concatenate %5526, %5529, %5532, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %5534 = "stablehlo.gather"(%5078, %5533) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %5535 = shape.shape_of %5534 : tensor<1x?x4096xf32> -> tensor<3xindex> + %5536 = shape.num_elements %5535 : tensor<3xindex> -> index + %5537 = stablehlo.compute_reshape_shape %5536, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %5538 = stablehlo.dynamic_reshape %5534, %5537 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %5539 = stablehlo.dot %5538, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %5540 = stablehlo.logistic %5539 : tensor + %5541 = shape.shape_of %5540 : tensor -> tensor<2xindex> + %5542 = shape.shape_of %5539 : tensor -> tensor<2xindex> + %5543 = shape.cstr_broadcastable %5541, %5542 : tensor<2xindex>, tensor<2xindex> + %5544 = shape.assuming %5543 -> (tensor) { + %19688 = shape.broadcast %5541, %5542 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %5540, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %5539, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %5545 = shape.shape_of %5544 : tensor -> tensor<2xindex> + %5546 = shape.cstr_broadcastable %5545, %5542 : tensor<2xindex>, tensor<2xindex> + %5547 = shape.assuming %5546 -> (tensor) { + %19688 = shape.broadcast %5545, %5542 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %5544, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %5539, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %5548 = stablehlo.dot %5547, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_1958 = tensor.dim %5520, %c0 : tensor + %5549 = arith.index_cast %dim_1958 : index to i64 + %from_elements_1959 = tensor.from_elements %5549, %c1_i64 : tensor<2xi64> + %5550 = stablehlo.dynamic_reshape %5520, %from_elements_1959 : (tensor, tensor<2xi64>) -> tensor + %dim_1960 = tensor.dim %5517, %c0 : tensor + %5551 = arith.index_cast %dim_1960 : index to i64 + %from_elements_1961 = tensor.from_elements %5551, %c1_i64 : tensor<2xi64> + %5552 = stablehlo.dynamic_reshape %5517, %from_elements_1961 : (tensor, tensor<2xi64>) -> tensor + %5553 = stablehlo.concatenate %5550, %5552, dim = 1 : (tensor, tensor) -> tensor + %5554 = "stablehlo.gather"(%5107, %5553) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %5555 = shape.shape_of %5548 : tensor -> tensor<2xindex> + %5556 = shape.shape_of %5554 : tensor -> tensor<2xindex> + %5557 = shape.cstr_broadcastable %5555, %5556 : tensor<2xindex>, tensor<2xindex> + %5558 = shape.assuming %5557 -> (tensor) { + %19688 = 
shape.broadcast %5555, %5556 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %5548, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %5554, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %5559 = shape.shape_of %5558 : tensor -> tensor<2xindex> + %5560 = stablehlo.dynamic_broadcast_in_dim %5558, %5559, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %5561 = stablehlo.dynamic_broadcast_in_dim %213, %5559, dims = [] : (tensor, tensor<2xindex>) -> tensor + %5562 = stablehlo.multiply %5560, %5561 : tensor + %dim_1962 = tensor.dim %5522, %c0 : tensor + %5563 = arith.index_cast %dim_1962 : index to i64 + %dim_1963 = tensor.dim %5558, %c0 : tensor + %5564 = arith.index_cast %dim_1963 : index to i64 + %5565 = arith.maxsi %5563, %5564 : i64 + %5566 = arith.index_cast %5565 : i64 to index + %from_elements_1964 = tensor.from_elements %5566, %c4096 : tensor<2xindex> + %5567 = stablehlo.dynamic_broadcast_in_dim %5522, %from_elements_1964, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_1965 = tensor.dim %5567, %c0 : tensor + %5568 = arith.index_cast %dim_1965 : index to i64 + %from_elements_1966 = tensor.from_elements %5568, %c4096_i64 : tensor<2xi64> + %5569 = stablehlo.real_dynamic_slice %5562, %c_22, %from_elements_1966, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_1967 = tensor.from_elements %5568, %c4096_i64, %c1_i64 : tensor<3xi64> + %5570 = stablehlo.dynamic_reshape %5567, %from_elements_1967 : (tensor, tensor<3xi64>) -> tensor + %5571 = stablehlo.dynamic_iota %from_elements_1967, dim = 1 : (tensor<3xi64>) -> tensor + %5572 = stablehlo.concatenate %5570, %5571, dim = 2 : (tensor, tensor) -> tensor + %5573 = "stablehlo.scatter"(%5510, %5572, %5569) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %5574 = stablehlo.reshape %5573 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %5575 = stablehlo.add %5040, %5574 : tensor<3x1x4096xf32> + %5576 = stablehlo.broadcast_in_dim %5575, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %5577 = stablehlo.power %5576, %15 : tensor<3x1x4096xf32> + %5578 = stablehlo.reduce(%5577 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %5579 = stablehlo.reshape %5578 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %5580 = stablehlo.broadcast_in_dim %5579, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %5581 = stablehlo.divide %5580, %21 : tensor<3x1x1xf32> + %5582 = stablehlo.broadcast_in_dim %5581, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %5583 = stablehlo.add %5582, %25 : tensor<3x1x1xf32> + %5584 = stablehlo.rsqrt %5583 : tensor<3x1x1xf32> + %5585 = stablehlo.broadcast_in_dim %5584, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %5586 = stablehlo.multiply %5576, %5585 : tensor<3x1x4096xf32> + %5587 = stablehlo.broadcast_in_dim %5586, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %5588 = stablehlo.multiply %5587, %31 : tensor<3x1x4096xf32> + %5589 = stablehlo.reshape %5588 : 
(tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %5590 = stablehlo.dot %5589, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %5591 = stablehlo.reshape %5590 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %5592 = stablehlo.dot %5589, %37 : (tensor<3x4096xf32>, tensor<4096x1024xf32>) -> tensor<3x1024xf32> + %5593 = stablehlo.reshape %5592 : (tensor<3x1024xf32>) -> tensor<3x1x1024xf32> + %5594 = stablehlo.reshape %5591 : (tensor<3x1x4096xf32>) -> tensor<3x1x32x128xf32> + %5595 = stablehlo.transpose %5594, dims = [0, 2, 1, 3] : (tensor<3x1x32x128xf32>) -> tensor<3x32x1x128xf32> + %5596 = stablehlo.reshape %5593 : (tensor<3x1x1024xf32>) -> tensor<3x1x8x128xf32> + %5597 = stablehlo.transpose %5596, dims = [0, 2, 1, 3] : (tensor<3x1x8x128xf32>) -> tensor<3x8x1x128xf32> + %5598 = stablehlo.slice %arg18 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %5599 = stablehlo.slice %arg19 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %5600 = "stablehlo.gather"(%5598, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %5601 = stablehlo.reshape %5600 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %5602 = "stablehlo.gather"(%5599, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %5603 = stablehlo.reshape %5602 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %5604 = stablehlo.broadcast_in_dim %5595, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %5605 = stablehlo.broadcast_in_dim %5601, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %5606 = stablehlo.multiply %5604, %5605 : tensor<3x32x1x128xf32> + %5607 = stablehlo.slice %5595 [0:3, 0:32, 0:1, 0:64] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %5608 = stablehlo.slice %5595 [0:3, 0:32, 0:1, 64:128] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %5609 = stablehlo.negate %5608 : tensor<3x32x1x64xf32> + %5610 = stablehlo.concatenate %5609, %5607, dim = 3 : (tensor<3x32x1x64xf32>, tensor<3x32x1x64xf32>) -> tensor<3x32x1x128xf32> + %5611 = stablehlo.broadcast_in_dim %5610, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %5612 = stablehlo.broadcast_in_dim %5603, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %5613 = stablehlo.multiply %5611, %5612 : tensor<3x32x1x128xf32> + %5614 = stablehlo.add %5606, %5613 : tensor<3x32x1x128xf32> + %5615 = stablehlo.broadcast_in_dim %5597, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %5616 = stablehlo.broadcast_in_dim %5601, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %5617 = stablehlo.multiply %5615, %5616 : tensor<3x8x1x128xf32> + %5618 = stablehlo.slice %5597 [0:3, 0:8, 0:1, 0:64] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %5619 = stablehlo.slice %5597 [0:3, 0:8, 0:1, 64:128] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %5620 = stablehlo.negate %5619 : tensor<3x8x1x64xf32> + %5621 = stablehlo.concatenate %5620, %5618, dim = 3 : (tensor<3x8x1x64xf32>, tensor<3x8x1x64xf32>) -> tensor<3x8x1x128xf32> + %5622 = stablehlo.broadcast_in_dim %5621, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %5623 = stablehlo.broadcast_in_dim %5603, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %5624 = 
stablehlo.multiply %5622, %5623 : tensor<3x8x1x128xf32> + %5625 = stablehlo.add %5617, %5624 : tensor<3x8x1x128xf32> + %5626 = stablehlo.concatenate %arg83, %5625, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %5627 = stablehlo.concatenate %arg84, %5597, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %5628 = stablehlo.reshape %5626 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %5629 = stablehlo.broadcast_in_dim %5628, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %5630 = stablehlo.reshape %5629 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %5631 = stablehlo.reshape %5627 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %5632 = stablehlo.broadcast_in_dim %5631, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %5633 = stablehlo.reshape %5632 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %5634 = stablehlo.transpose %5630, dims = [0, 1, 3, 2] : (tensor<3x32x8x128xf32>) -> tensor<3x32x128x8xf32> + %5635 = stablehlo.reshape %5614 : (tensor<3x32x1x128xf32>) -> tensor<96x1x128xf32> + %5636 = stablehlo.reshape %5634 : (tensor<3x32x128x8xf32>) -> tensor<96x128x8xf32> + %5637 = stablehlo.broadcast_in_dim %5636, dims = [0, 1, 2] : (tensor<96x128x8xf32>) -> tensor<96x128x8xf32> + %5638 = stablehlo.dot_general %5635, %5637, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x128xf32>, tensor<96x128x8xf32>) -> tensor<96x1x8xf32> + %5639 = stablehlo.reshape %5638 : (tensor<96x1x8xf32>) -> tensor<3x32x1x8xf32> + %5640 = stablehlo.broadcast_in_dim %5639, dims = [0, 1, 2, 3] : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %5641 = stablehlo.divide %5640, %89 : tensor<3x32x1x8xf32> + %5642 = stablehlo.custom_call @byteir.softmax(%5641) {byteir_attrs = {axis = 3 : i64}} : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %5643 = stablehlo.reshape %5642 : (tensor<3x32x1x8xf32>) -> tensor<96x1x8xf32> + %5644 = stablehlo.reshape %5633 : (tensor<3x32x8x128xf32>) -> tensor<96x8x128xf32> + %5645 = stablehlo.broadcast_in_dim %5644, dims = [0, 1, 2] : (tensor<96x8x128xf32>) -> tensor<96x8x128xf32> + %5646 = stablehlo.dot_general %5643, %5645, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x8xf32>, tensor<96x8x128xf32>) -> tensor<96x1x128xf32> + %5647 = stablehlo.reshape %5646 : (tensor<96x1x128xf32>) -> tensor<3x32x1x128xf32> + %5648 = stablehlo.transpose %5647, dims = [0, 2, 1, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x1x32x128xf32> + %5649 = stablehlo.reshape %5648 : (tensor<3x1x32x128xf32>) -> tensor<3x1x4096xf32> + %5650 = stablehlo.reshape %5649 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %5651 = stablehlo.dot %5650, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %5652 = stablehlo.reshape %5651 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %5653 = stablehlo.add %5575, %5652 : tensor<3x1x4096xf32> + %5654 = stablehlo.broadcast_in_dim %5653, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %5655 = stablehlo.power %5654, %15 : tensor<3x1x4096xf32> + %5656 = stablehlo.reduce(%5655 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %5657 = stablehlo.reshape %5656 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %5658 = stablehlo.broadcast_in_dim %5657, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %5659 = stablehlo.divide %5658, %21 : tensor<3x1x1xf32> + %5660 = 
stablehlo.broadcast_in_dim %5659, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %5661 = stablehlo.add %5660, %25 : tensor<3x1x1xf32> + %5662 = stablehlo.rsqrt %5661 : tensor<3x1x1xf32> + %5663 = stablehlo.broadcast_in_dim %5662, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %5664 = stablehlo.multiply %5654, %5663 : tensor<3x1x4096xf32> + %5665 = stablehlo.broadcast_in_dim %5664, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %5666 = stablehlo.multiply %5665, %31 : tensor<3x1x4096xf32> + %5667 = stablehlo.reshape %5666 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %5668 = stablehlo.dot %5667, %117 : (tensor<3x4096xf32>, tensor<4096x8xf32>) -> tensor<3x8xf32> + %5669 = stablehlo.custom_call @byteir.softmax(%5668) {byteir_attrs = {axis = 1 : i64}} : (tensor<3x8xf32>) -> tensor<3x8xf32> + %5670:2 = stablehlo.custom_call @byteir.top_k(%5669) {byteir_attrs = {axis = [1], k = 2 : i64, sorted = true}} : (tensor<3x8xf32>) -> (tensor<3x2xf32>, tensor<3x2xi64>) + %5671 = stablehlo.reduce(%5670#0 init: %cst_20) applies stablehlo.add across dimensions = [1] : (tensor<3x2xf32>, tensor) -> tensor<3xf32> + %5672 = stablehlo.reshape %5671 : (tensor<3xf32>) -> tensor<3x1xf32> + %5673 = stablehlo.broadcast_in_dim %5670#0, dims = [0, 1] : (tensor<3x2xf32>) -> tensor<3x2xf32> + %5674 = stablehlo.broadcast_in_dim %5672, dims = [0, 1] : (tensor<3x1xf32>) -> tensor<3x2xf32> + %5675 = stablehlo.divide %5673, %5674 : tensor<3x2xf32> + %5676 = stablehlo.reshape %5670#1 : (tensor<3x2xi64>) -> tensor<3x2x1xi64> + %5677 = stablehlo.broadcast_in_dim %5676, dims = [0, 1, 2] : (tensor<3x2x1xi64>) -> tensor<3x2x8xi64> + %5678 = stablehlo.compare EQ, %5677, %137, SIGNED : (tensor<3x2x8xi64>, tensor<3x2x8xi64>) -> tensor<3x2x8xi1> + %5679 = stablehlo.convert %5678 : (tensor<3x2x8xi1>) -> tensor<3x2x8xi64> + %5680 = stablehlo.transpose %5679, dims = [2, 1, 0] : (tensor<3x2x8xi64>) -> tensor<8x2x3xi64> + %5681 = stablehlo.slice %5680 [0:1, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %5682 = stablehlo.reshape %5681 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %5683 = stablehlo.custom_call @byteir.non_zero(%5682) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_1968 = tensor.dim %5683, %c0 : tensor + %5684 = arith.index_cast %dim_1968 : index to i64 + %from_elements_1969 = tensor.from_elements %5684, %c1_i64 : tensor<2xi64> + %5685 = stablehlo.real_dynamic_slice %5683, %c_22, %from_elements_1969, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1970 = tensor.dim %5685, %c0 : tensor + %5686 = arith.index_cast %dim_1970 : index to i64 + %from_elements_1971 = tensor.from_elements %5686 : tensor<1xi64> + %5687 = stablehlo.dynamic_reshape %5685, %from_elements_1971 : (tensor, tensor<1xi64>) -> tensor + %from_elements_1972 = tensor.from_elements %5684, %c2_i64 : tensor<2xi64> + %5688 = stablehlo.real_dynamic_slice %5683, %c_24, %from_elements_1972, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1973 = tensor.dim %5688, %c0 : tensor + %5689 = arith.index_cast %dim_1973 : index to i64 + %from_elements_1974 = tensor.from_elements %5689 : tensor<1xi64> + %5690 = stablehlo.dynamic_reshape %5688, %from_elements_1974 : (tensor, tensor<1xi64>) -> tensor + %5691 = stablehlo.reshape %5667 : (tensor<3x4096xf32>) -> tensor<1x3x4096xf32> + %dim_1975 = tensor.dim %5690, %c0 : tensor + %5692 = arith.index_cast %dim_1975 : index to i64 + %from_elements_1976 = tensor.from_elements %5692, %c1_i64 : tensor<2xi64> + 
%5693 = stablehlo.dynamic_reshape %5690, %from_elements_1976 : (tensor, tensor<2xi64>) -> tensor + %dim_1977 = tensor.dim %5693, %c0 : tensor + %5694 = arith.index_cast %dim_1977 : index to i64 + %from_elements_1978 = tensor.from_elements %c1_i64, %5694, %c4096_i64 : tensor<3xi64> + %5695 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_1978, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1979 = tensor.dim %5695, %c1 : tensor<1x?x4096xi64> + %5696 = arith.index_cast %dim_1979 : index to i64 + %from_elements_1980 = tensor.from_elements %c1_i64, %5696, %c4096_i64, %c1_i64 : tensor<4xi64> + %5697 = stablehlo.dynamic_reshape %5695, %from_elements_1980 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %5698 = stablehlo.dynamic_broadcast_in_dim %5693, %from_elements_1978, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1981 = tensor.dim %5698, %c1 : tensor<1x?x4096xi64> + %5699 = arith.index_cast %dim_1981 : index to i64 + %from_elements_1982 = tensor.from_elements %c1_i64, %5699, %c4096_i64, %c1_i64 : tensor<4xi64> + %5700 = stablehlo.dynamic_reshape %5698, %from_elements_1982 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %5701 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_1978, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_1983 = tensor.dim %5701, %c1 : tensor<1x?x4096xi64> + %5702 = arith.index_cast %dim_1983 : index to i64 + %from_elements_1984 = tensor.from_elements %c1_i64, %5702, %c4096_i64, %c1_i64 : tensor<4xi64> + %5703 = stablehlo.dynamic_reshape %5701, %from_elements_1984 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %5704 = stablehlo.concatenate %5697, %5700, %5703, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %5705 = "stablehlo.gather"(%5691, %5704) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %5706 = shape.shape_of %5705 : tensor<1x?x4096xf32> -> tensor<3xindex> + %5707 = shape.num_elements %5706 : tensor<3xindex> -> index + %5708 = stablehlo.compute_reshape_shape %5707, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %5709 = stablehlo.dynamic_reshape %5705, %5708 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %5710 = stablehlo.dot %5709, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %5711 = stablehlo.logistic %5710 : tensor + %5712 = shape.shape_of %5711 : tensor -> tensor<2xindex> + %5713 = shape.shape_of %5710 : tensor -> tensor<2xindex> + %5714 = shape.cstr_broadcastable %5712, %5713 : tensor<2xindex>, tensor<2xindex> + %5715 = shape.assuming %5714 -> (tensor) { + %19688 = shape.broadcast %5712, %5713 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %5711, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %5710, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %5716 = shape.shape_of %5715 : tensor -> tensor<2xindex> + %5717 = shape.cstr_broadcastable %5716, %5713 : tensor<2xindex>, tensor<2xindex> + %5718 = shape.assuming %5717 -> (tensor) { + %19688 = shape.broadcast %5716, %5713 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %5715, %19688, 
dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %5710, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %5719 = stablehlo.dot %5718, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %5720 = stablehlo.reshape %5675 : (tensor<3x2xf32>) -> tensor<3x2x1xf32> + %dim_1985 = tensor.dim %5690, %c0 : tensor + %5721 = arith.index_cast %dim_1985 : index to i64 + %from_elements_1986 = tensor.from_elements %5721, %c1_i64 : tensor<2xi64> + %5722 = stablehlo.dynamic_reshape %5690, %from_elements_1986 : (tensor, tensor<2xi64>) -> tensor + %dim_1987 = tensor.dim %5687, %c0 : tensor + %5723 = arith.index_cast %dim_1987 : index to i64 + %from_elements_1988 = tensor.from_elements %5723, %c1_i64 : tensor<2xi64> + %5724 = stablehlo.dynamic_reshape %5687, %from_elements_1988 : (tensor, tensor<2xi64>) -> tensor + %5725 = stablehlo.concatenate %5722, %5724, dim = 1 : (tensor, tensor) -> tensor + %5726 = "stablehlo.gather"(%5720, %5725) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %5727 = shape.shape_of %5719 : tensor -> tensor<2xindex> + %5728 = shape.shape_of %5726 : tensor -> tensor<2xindex> + %5729 = shape.cstr_broadcastable %5727, %5728 : tensor<2xindex>, tensor<2xindex> + %5730 = shape.assuming %5729 -> (tensor) { + %19688 = shape.broadcast %5727, %5728 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %5719, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %5726, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %5731 = shape.shape_of %5730 : tensor -> tensor<2xindex> + %5732 = stablehlo.dynamic_broadcast_in_dim %5730, %5731, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %5733 = stablehlo.dynamic_broadcast_in_dim %213, %5731, dims = [] : (tensor, tensor<2xindex>) -> tensor + %5734 = stablehlo.multiply %5732, %5733 : tensor + %dim_1989 = tensor.dim %5693, %c0 : tensor + %5735 = arith.index_cast %dim_1989 : index to i64 + %dim_1990 = tensor.dim %5730, %c0 : tensor + %5736 = arith.index_cast %dim_1990 : index to i64 + %5737 = arith.maxsi %5735, %5736 : i64 + %5738 = arith.index_cast %5737 : i64 to index + %from_elements_1991 = tensor.from_elements %5738, %c4096 : tensor<2xindex> + %5739 = stablehlo.dynamic_broadcast_in_dim %5693, %from_elements_1991, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_1992 = tensor.dim %5739, %c0 : tensor + %5740 = arith.index_cast %dim_1992 : index to i64 + %from_elements_1993 = tensor.from_elements %5740, %c4096_i64 : tensor<2xi64> + %5741 = stablehlo.real_dynamic_slice %5734, %c_22, %from_elements_1993, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_1994 = tensor.from_elements %5740, %c4096_i64, %c1_i64 : tensor<3xi64> + %5742 = stablehlo.dynamic_reshape %5739, %from_elements_1994 : (tensor, tensor<3xi64>) -> tensor + %5743 = stablehlo.dynamic_iota %from_elements_1994, dim = 1 : (tensor<3xi64>) -> tensor + %5744 = stablehlo.concatenate %5742, %5743, dim = 2 : (tensor, tensor) -> tensor + %5745 = "stablehlo.scatter"(%cst_2, %5744, %5741) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: 
tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %5746 = stablehlo.slice %5680 [1:2, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %5747 = stablehlo.reshape %5746 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %5748 = stablehlo.custom_call @byteir.non_zero(%5747) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_1995 = tensor.dim %5748, %c0 : tensor + %5749 = arith.index_cast %dim_1995 : index to i64 + %from_elements_1996 = tensor.from_elements %5749, %c1_i64 : tensor<2xi64> + %5750 = stablehlo.real_dynamic_slice %5748, %c_22, %from_elements_1996, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_1997 = tensor.dim %5750, %c0 : tensor + %5751 = arith.index_cast %dim_1997 : index to i64 + %from_elements_1998 = tensor.from_elements %5751 : tensor<1xi64> + %5752 = stablehlo.dynamic_reshape %5750, %from_elements_1998 : (tensor, tensor<1xi64>) -> tensor + %from_elements_1999 = tensor.from_elements %5749, %c2_i64 : tensor<2xi64> + %5753 = stablehlo.real_dynamic_slice %5748, %c_24, %from_elements_1999, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2000 = tensor.dim %5753, %c0 : tensor + %5754 = arith.index_cast %dim_2000 : index to i64 + %from_elements_2001 = tensor.from_elements %5754 : tensor<1xi64> + %5755 = stablehlo.dynamic_reshape %5753, %from_elements_2001 : (tensor, tensor<1xi64>) -> tensor + %dim_2002 = tensor.dim %5755, %c0 : tensor + %5756 = arith.index_cast %dim_2002 : index to i64 + %from_elements_2003 = tensor.from_elements %5756, %c1_i64 : tensor<2xi64> + %5757 = stablehlo.dynamic_reshape %5755, %from_elements_2003 : (tensor, tensor<2xi64>) -> tensor + %dim_2004 = tensor.dim %5757, %c0 : tensor + %5758 = arith.index_cast %dim_2004 : index to i64 + %from_elements_2005 = tensor.from_elements %c1_i64, %5758, %c4096_i64 : tensor<3xi64> + %5759 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_2005, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2006 = tensor.dim %5759, %c1 : tensor<1x?x4096xi64> + %5760 = arith.index_cast %dim_2006 : index to i64 + %from_elements_2007 = tensor.from_elements %c1_i64, %5760, %c4096_i64, %c1_i64 : tensor<4xi64> + %5761 = stablehlo.dynamic_reshape %5759, %from_elements_2007 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %5762 = stablehlo.dynamic_broadcast_in_dim %5757, %from_elements_2005, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2008 = tensor.dim %5762, %c1 : tensor<1x?x4096xi64> + %5763 = arith.index_cast %dim_2008 : index to i64 + %from_elements_2009 = tensor.from_elements %c1_i64, %5763, %c4096_i64, %c1_i64 : tensor<4xi64> + %5764 = stablehlo.dynamic_reshape %5762, %from_elements_2009 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %5765 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_2005, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2010 = tensor.dim %5765, %c1 : tensor<1x?x4096xi64> + %5766 = arith.index_cast %dim_2010 : index to i64 + %from_elements_2011 = tensor.from_elements %c1_i64, %5766, %c4096_i64, %c1_i64 : tensor<4xi64> + %5767 = stablehlo.dynamic_reshape %5765, %from_elements_2011 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %5768 = stablehlo.concatenate %5761, %5764, %5767, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> 
tensor<1x?x4096x3xi64> + %5769 = "stablehlo.gather"(%5691, %5768) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %5770 = shape.shape_of %5769 : tensor<1x?x4096xf32> -> tensor<3xindex> + %5771 = shape.num_elements %5770 : tensor<3xindex> -> index + %5772 = stablehlo.compute_reshape_shape %5771, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %5773 = stablehlo.dynamic_reshape %5769, %5772 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %5774 = stablehlo.dot %5773, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %5775 = stablehlo.logistic %5774 : tensor + %5776 = shape.shape_of %5775 : tensor -> tensor<2xindex> + %5777 = shape.shape_of %5774 : tensor -> tensor<2xindex> + %5778 = shape.cstr_broadcastable %5776, %5777 : tensor<2xindex>, tensor<2xindex> + %5779 = shape.assuming %5778 -> (tensor) { + %19688 = shape.broadcast %5776, %5777 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %5775, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %5774, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %5780 = shape.shape_of %5779 : tensor -> tensor<2xindex> + %5781 = shape.cstr_broadcastable %5780, %5777 : tensor<2xindex>, tensor<2xindex> + %5782 = shape.assuming %5781 -> (tensor) { + %19688 = shape.broadcast %5780, %5777 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %5779, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %5774, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %5783 = stablehlo.dot %5782, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_2012 = tensor.dim %5755, %c0 : tensor + %5784 = arith.index_cast %dim_2012 : index to i64 + %from_elements_2013 = tensor.from_elements %5784, %c1_i64 : tensor<2xi64> + %5785 = stablehlo.dynamic_reshape %5755, %from_elements_2013 : (tensor, tensor<2xi64>) -> tensor + %dim_2014 = tensor.dim %5752, %c0 : tensor + %5786 = arith.index_cast %dim_2014 : index to i64 + %from_elements_2015 = tensor.from_elements %5786, %c1_i64 : tensor<2xi64> + %5787 = stablehlo.dynamic_reshape %5752, %from_elements_2015 : (tensor, tensor<2xi64>) -> tensor + %5788 = stablehlo.concatenate %5785, %5787, dim = 1 : (tensor, tensor) -> tensor + %5789 = "stablehlo.gather"(%5720, %5788) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %5790 = shape.shape_of %5783 : tensor -> tensor<2xindex> + %5791 = shape.shape_of %5789 : tensor -> tensor<2xindex> + %5792 = shape.cstr_broadcastable %5790, %5791 : tensor<2xindex>, tensor<2xindex> + %5793 = shape.assuming %5792 -> (tensor) { + %19688 = shape.broadcast %5790, %5791 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %5783, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %5789, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %5794 = shape.shape_of %5793 : tensor -> 
tensor<2xindex> + %5795 = stablehlo.dynamic_broadcast_in_dim %5793, %5794, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %5796 = stablehlo.dynamic_broadcast_in_dim %213, %5794, dims = [] : (tensor, tensor<2xindex>) -> tensor + %5797 = stablehlo.multiply %5795, %5796 : tensor + %dim_2016 = tensor.dim %5757, %c0 : tensor + %5798 = arith.index_cast %dim_2016 : index to i64 + %dim_2017 = tensor.dim %5793, %c0 : tensor + %5799 = arith.index_cast %dim_2017 : index to i64 + %5800 = arith.maxsi %5798, %5799 : i64 + %5801 = arith.index_cast %5800 : i64 to index + %from_elements_2018 = tensor.from_elements %5801, %c4096 : tensor<2xindex> + %5802 = stablehlo.dynamic_broadcast_in_dim %5757, %from_elements_2018, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_2019 = tensor.dim %5802, %c0 : tensor + %5803 = arith.index_cast %dim_2019 : index to i64 + %from_elements_2020 = tensor.from_elements %5803, %c4096_i64 : tensor<2xi64> + %5804 = stablehlo.real_dynamic_slice %5797, %c_22, %from_elements_2020, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_2021 = tensor.from_elements %5803, %c4096_i64, %c1_i64 : tensor<3xi64> + %5805 = stablehlo.dynamic_reshape %5802, %from_elements_2021 : (tensor, tensor<3xi64>) -> tensor + %5806 = stablehlo.dynamic_iota %from_elements_2021, dim = 1 : (tensor<3xi64>) -> tensor + %5807 = stablehlo.concatenate %5805, %5806, dim = 2 : (tensor, tensor) -> tensor + %5808 = "stablehlo.scatter"(%5745, %5807, %5804) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %5809 = stablehlo.slice %5680 [2:3, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %5810 = stablehlo.reshape %5809 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %5811 = stablehlo.custom_call @byteir.non_zero(%5810) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_2022 = tensor.dim %5811, %c0 : tensor + %5812 = arith.index_cast %dim_2022 : index to i64 + %from_elements_2023 = tensor.from_elements %5812, %c1_i64 : tensor<2xi64> + %5813 = stablehlo.real_dynamic_slice %5811, %c_22, %from_elements_2023, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2024 = tensor.dim %5813, %c0 : tensor + %5814 = arith.index_cast %dim_2024 : index to i64 + %from_elements_2025 = tensor.from_elements %5814 : tensor<1xi64> + %5815 = stablehlo.dynamic_reshape %5813, %from_elements_2025 : (tensor, tensor<1xi64>) -> tensor + %from_elements_2026 = tensor.from_elements %5812, %c2_i64 : tensor<2xi64> + %5816 = stablehlo.real_dynamic_slice %5811, %c_24, %from_elements_2026, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2027 = tensor.dim %5816, %c0 : tensor + %5817 = arith.index_cast %dim_2027 : index to i64 + %from_elements_2028 = tensor.from_elements %5817 : tensor<1xi64> + %5818 = stablehlo.dynamic_reshape %5816, %from_elements_2028 : (tensor, tensor<1xi64>) -> tensor + %dim_2029 = tensor.dim %5818, %c0 : tensor + %5819 = arith.index_cast %dim_2029 : index to i64 + %from_elements_2030 = tensor.from_elements %5819, %c1_i64 : tensor<2xi64> + %5820 = stablehlo.dynamic_reshape %5818, %from_elements_2030 : (tensor, tensor<2xi64>) -> tensor + %dim_2031 = tensor.dim %5820, %c0 : tensor + %5821 = arith.index_cast %dim_2031 : index to i64 + %from_elements_2032 = tensor.from_elements 
%c1_i64, %5821, %c4096_i64 : tensor<3xi64> + %5822 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_2032, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2033 = tensor.dim %5822, %c1 : tensor<1x?x4096xi64> + %5823 = arith.index_cast %dim_2033 : index to i64 + %from_elements_2034 = tensor.from_elements %c1_i64, %5823, %c4096_i64, %c1_i64 : tensor<4xi64> + %5824 = stablehlo.dynamic_reshape %5822, %from_elements_2034 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %5825 = stablehlo.dynamic_broadcast_in_dim %5820, %from_elements_2032, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2035 = tensor.dim %5825, %c1 : tensor<1x?x4096xi64> + %5826 = arith.index_cast %dim_2035 : index to i64 + %from_elements_2036 = tensor.from_elements %c1_i64, %5826, %c4096_i64, %c1_i64 : tensor<4xi64> + %5827 = stablehlo.dynamic_reshape %5825, %from_elements_2036 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %5828 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_2032, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2037 = tensor.dim %5828, %c1 : tensor<1x?x4096xi64> + %5829 = arith.index_cast %dim_2037 : index to i64 + %from_elements_2038 = tensor.from_elements %c1_i64, %5829, %c4096_i64, %c1_i64 : tensor<4xi64> + %5830 = stablehlo.dynamic_reshape %5828, %from_elements_2038 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %5831 = stablehlo.concatenate %5824, %5827, %5830, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %5832 = "stablehlo.gather"(%5691, %5831) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %5833 = shape.shape_of %5832 : tensor<1x?x4096xf32> -> tensor<3xindex> + %5834 = shape.num_elements %5833 : tensor<3xindex> -> index + %5835 = stablehlo.compute_reshape_shape %5834, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %5836 = stablehlo.dynamic_reshape %5832, %5835 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %5837 = stablehlo.dot %5836, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %5838 = stablehlo.logistic %5837 : tensor + %5839 = shape.shape_of %5838 : tensor -> tensor<2xindex> + %5840 = shape.shape_of %5837 : tensor -> tensor<2xindex> + %5841 = shape.cstr_broadcastable %5839, %5840 : tensor<2xindex>, tensor<2xindex> + %5842 = shape.assuming %5841 -> (tensor) { + %19688 = shape.broadcast %5839, %5840 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %5838, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %5837, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %5843 = shape.shape_of %5842 : tensor -> tensor<2xindex> + %5844 = shape.cstr_broadcastable %5843, %5840 : tensor<2xindex>, tensor<2xindex> + %5845 = shape.assuming %5844 -> (tensor) { + %19688 = shape.broadcast %5843, %5840 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %5842, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %5837, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + 
shape.assuming_yield %19691 : tensor + } + %5846 = stablehlo.dot %5845, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_2039 = tensor.dim %5818, %c0 : tensor + %5847 = arith.index_cast %dim_2039 : index to i64 + %from_elements_2040 = tensor.from_elements %5847, %c1_i64 : tensor<2xi64> + %5848 = stablehlo.dynamic_reshape %5818, %from_elements_2040 : (tensor, tensor<2xi64>) -> tensor + %dim_2041 = tensor.dim %5815, %c0 : tensor + %5849 = arith.index_cast %dim_2041 : index to i64 + %from_elements_2042 = tensor.from_elements %5849, %c1_i64 : tensor<2xi64> + %5850 = stablehlo.dynamic_reshape %5815, %from_elements_2042 : (tensor, tensor<2xi64>) -> tensor + %5851 = stablehlo.concatenate %5848, %5850, dim = 1 : (tensor, tensor) -> tensor + %5852 = "stablehlo.gather"(%5720, %5851) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %5853 = shape.shape_of %5846 : tensor -> tensor<2xindex> + %5854 = shape.shape_of %5852 : tensor -> tensor<2xindex> + %5855 = shape.cstr_broadcastable %5853, %5854 : tensor<2xindex>, tensor<2xindex> + %5856 = shape.assuming %5855 -> (tensor) { + %19688 = shape.broadcast %5853, %5854 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %5846, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %5852, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %5857 = shape.shape_of %5856 : tensor -> tensor<2xindex> + %5858 = stablehlo.dynamic_broadcast_in_dim %5856, %5857, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %5859 = stablehlo.dynamic_broadcast_in_dim %213, %5857, dims = [] : (tensor, tensor<2xindex>) -> tensor + %5860 = stablehlo.multiply %5858, %5859 : tensor + %dim_2043 = tensor.dim %5820, %c0 : tensor + %5861 = arith.index_cast %dim_2043 : index to i64 + %dim_2044 = tensor.dim %5856, %c0 : tensor + %5862 = arith.index_cast %dim_2044 : index to i64 + %5863 = arith.maxsi %5861, %5862 : i64 + %5864 = arith.index_cast %5863 : i64 to index + %from_elements_2045 = tensor.from_elements %5864, %c4096 : tensor<2xindex> + %5865 = stablehlo.dynamic_broadcast_in_dim %5820, %from_elements_2045, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_2046 = tensor.dim %5865, %c0 : tensor + %5866 = arith.index_cast %dim_2046 : index to i64 + %from_elements_2047 = tensor.from_elements %5866, %c4096_i64 : tensor<2xi64> + %5867 = stablehlo.real_dynamic_slice %5860, %c_22, %from_elements_2047, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_2048 = tensor.from_elements %5866, %c4096_i64, %c1_i64 : tensor<3xi64> + %5868 = stablehlo.dynamic_reshape %5865, %from_elements_2048 : (tensor, tensor<3xi64>) -> tensor + %5869 = stablehlo.dynamic_iota %from_elements_2048, dim = 1 : (tensor<3xi64>) -> tensor + %5870 = stablehlo.concatenate %5868, %5869, dim = 2 : (tensor, tensor) -> tensor + %5871 = "stablehlo.scatter"(%5808, %5870, %5867) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %5872 = stablehlo.slice %5680 [3:4, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %5873 = stablehlo.reshape 
%5872 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %5874 = stablehlo.custom_call @byteir.non_zero(%5873) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_2049 = tensor.dim %5874, %c0 : tensor + %5875 = arith.index_cast %dim_2049 : index to i64 + %from_elements_2050 = tensor.from_elements %5875, %c1_i64 : tensor<2xi64> + %5876 = stablehlo.real_dynamic_slice %5874, %c_22, %from_elements_2050, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2051 = tensor.dim %5876, %c0 : tensor + %5877 = arith.index_cast %dim_2051 : index to i64 + %from_elements_2052 = tensor.from_elements %5877 : tensor<1xi64> + %5878 = stablehlo.dynamic_reshape %5876, %from_elements_2052 : (tensor, tensor<1xi64>) -> tensor + %from_elements_2053 = tensor.from_elements %5875, %c2_i64 : tensor<2xi64> + %5879 = stablehlo.real_dynamic_slice %5874, %c_24, %from_elements_2053, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2054 = tensor.dim %5879, %c0 : tensor + %5880 = arith.index_cast %dim_2054 : index to i64 + %from_elements_2055 = tensor.from_elements %5880 : tensor<1xi64> + %5881 = stablehlo.dynamic_reshape %5879, %from_elements_2055 : (tensor, tensor<1xi64>) -> tensor + %dim_2056 = tensor.dim %5881, %c0 : tensor + %5882 = arith.index_cast %dim_2056 : index to i64 + %from_elements_2057 = tensor.from_elements %5882, %c1_i64 : tensor<2xi64> + %5883 = stablehlo.dynamic_reshape %5881, %from_elements_2057 : (tensor, tensor<2xi64>) -> tensor + %dim_2058 = tensor.dim %5883, %c0 : tensor + %5884 = arith.index_cast %dim_2058 : index to i64 + %from_elements_2059 = tensor.from_elements %c1_i64, %5884, %c4096_i64 : tensor<3xi64> + %5885 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_2059, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2060 = tensor.dim %5885, %c1 : tensor<1x?x4096xi64> + %5886 = arith.index_cast %dim_2060 : index to i64 + %from_elements_2061 = tensor.from_elements %c1_i64, %5886, %c4096_i64, %c1_i64 : tensor<4xi64> + %5887 = stablehlo.dynamic_reshape %5885, %from_elements_2061 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %5888 = stablehlo.dynamic_broadcast_in_dim %5883, %from_elements_2059, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2062 = tensor.dim %5888, %c1 : tensor<1x?x4096xi64> + %5889 = arith.index_cast %dim_2062 : index to i64 + %from_elements_2063 = tensor.from_elements %c1_i64, %5889, %c4096_i64, %c1_i64 : tensor<4xi64> + %5890 = stablehlo.dynamic_reshape %5888, %from_elements_2063 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %5891 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_2059, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2064 = tensor.dim %5891, %c1 : tensor<1x?x4096xi64> + %5892 = arith.index_cast %dim_2064 : index to i64 + %from_elements_2065 = tensor.from_elements %c1_i64, %5892, %c4096_i64, %c1_i64 : tensor<4xi64> + %5893 = stablehlo.dynamic_reshape %5891, %from_elements_2065 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %5894 = stablehlo.concatenate %5887, %5890, %5893, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %5895 = "stablehlo.gather"(%5691, %5894) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %5896 = shape.shape_of %5895 : tensor<1x?x4096xf32> -> 
tensor<3xindex> + %5897 = shape.num_elements %5896 : tensor<3xindex> -> index + %5898 = stablehlo.compute_reshape_shape %5897, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %5899 = stablehlo.dynamic_reshape %5895, %5898 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %5900 = stablehlo.dot %5899, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %5901 = stablehlo.logistic %5900 : tensor + %5902 = shape.shape_of %5901 : tensor -> tensor<2xindex> + %5903 = shape.shape_of %5900 : tensor -> tensor<2xindex> + %5904 = shape.cstr_broadcastable %5902, %5903 : tensor<2xindex>, tensor<2xindex> + %5905 = shape.assuming %5904 -> (tensor) { + %19688 = shape.broadcast %5902, %5903 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %5901, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %5900, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %5906 = shape.shape_of %5905 : tensor -> tensor<2xindex> + %5907 = shape.cstr_broadcastable %5906, %5903 : tensor<2xindex>, tensor<2xindex> + %5908 = shape.assuming %5907 -> (tensor) { + %19688 = shape.broadcast %5906, %5903 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %5905, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %5900, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %5909 = stablehlo.dot %5908, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_2066 = tensor.dim %5881, %c0 : tensor + %5910 = arith.index_cast %dim_2066 : index to i64 + %from_elements_2067 = tensor.from_elements %5910, %c1_i64 : tensor<2xi64> + %5911 = stablehlo.dynamic_reshape %5881, %from_elements_2067 : (tensor, tensor<2xi64>) -> tensor + %dim_2068 = tensor.dim %5878, %c0 : tensor + %5912 = arith.index_cast %dim_2068 : index to i64 + %from_elements_2069 = tensor.from_elements %5912, %c1_i64 : tensor<2xi64> + %5913 = stablehlo.dynamic_reshape %5878, %from_elements_2069 : (tensor, tensor<2xi64>) -> tensor + %5914 = stablehlo.concatenate %5911, %5913, dim = 1 : (tensor, tensor) -> tensor + %5915 = "stablehlo.gather"(%5720, %5914) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %5916 = shape.shape_of %5909 : tensor -> tensor<2xindex> + %5917 = shape.shape_of %5915 : tensor -> tensor<2xindex> + %5918 = shape.cstr_broadcastable %5916, %5917 : tensor<2xindex>, tensor<2xindex> + %5919 = shape.assuming %5918 -> (tensor) { + %19688 = shape.broadcast %5916, %5917 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %5909, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %5915, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %5920 = shape.shape_of %5919 : tensor -> tensor<2xindex> + %5921 = stablehlo.dynamic_broadcast_in_dim %5919, %5920, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %5922 = stablehlo.dynamic_broadcast_in_dim %213, %5920, dims = [] : (tensor, tensor<2xindex>) -> tensor + %5923 = stablehlo.multiply %5921, %5922 : tensor + %dim_2070 = 
tensor.dim %5883, %c0 : tensor + %5924 = arith.index_cast %dim_2070 : index to i64 + %dim_2071 = tensor.dim %5919, %c0 : tensor + %5925 = arith.index_cast %dim_2071 : index to i64 + %5926 = arith.maxsi %5924, %5925 : i64 + %5927 = arith.index_cast %5926 : i64 to index + %from_elements_2072 = tensor.from_elements %5927, %c4096 : tensor<2xindex> + %5928 = stablehlo.dynamic_broadcast_in_dim %5883, %from_elements_2072, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_2073 = tensor.dim %5928, %c0 : tensor + %5929 = arith.index_cast %dim_2073 : index to i64 + %from_elements_2074 = tensor.from_elements %5929, %c4096_i64 : tensor<2xi64> + %5930 = stablehlo.real_dynamic_slice %5923, %c_22, %from_elements_2074, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_2075 = tensor.from_elements %5929, %c4096_i64, %c1_i64 : tensor<3xi64> + %5931 = stablehlo.dynamic_reshape %5928, %from_elements_2075 : (tensor, tensor<3xi64>) -> tensor + %5932 = stablehlo.dynamic_iota %from_elements_2075, dim = 1 : (tensor<3xi64>) -> tensor + %5933 = stablehlo.concatenate %5931, %5932, dim = 2 : (tensor, tensor) -> tensor + %5934 = "stablehlo.scatter"(%5871, %5933, %5930) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %5935 = stablehlo.slice %5680 [4:5, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %5936 = stablehlo.reshape %5935 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %5937 = stablehlo.custom_call @byteir.non_zero(%5936) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_2076 = tensor.dim %5937, %c0 : tensor + %5938 = arith.index_cast %dim_2076 : index to i64 + %from_elements_2077 = tensor.from_elements %5938, %c1_i64 : tensor<2xi64> + %5939 = stablehlo.real_dynamic_slice %5937, %c_22, %from_elements_2077, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2078 = tensor.dim %5939, %c0 : tensor + %5940 = arith.index_cast %dim_2078 : index to i64 + %from_elements_2079 = tensor.from_elements %5940 : tensor<1xi64> + %5941 = stablehlo.dynamic_reshape %5939, %from_elements_2079 : (tensor, tensor<1xi64>) -> tensor + %from_elements_2080 = tensor.from_elements %5938, %c2_i64 : tensor<2xi64> + %5942 = stablehlo.real_dynamic_slice %5937, %c_24, %from_elements_2080, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2081 = tensor.dim %5942, %c0 : tensor + %5943 = arith.index_cast %dim_2081 : index to i64 + %from_elements_2082 = tensor.from_elements %5943 : tensor<1xi64> + %5944 = stablehlo.dynamic_reshape %5942, %from_elements_2082 : (tensor, tensor<1xi64>) -> tensor + %dim_2083 = tensor.dim %5944, %c0 : tensor + %5945 = arith.index_cast %dim_2083 : index to i64 + %from_elements_2084 = tensor.from_elements %5945, %c1_i64 : tensor<2xi64> + %5946 = stablehlo.dynamic_reshape %5944, %from_elements_2084 : (tensor, tensor<2xi64>) -> tensor + %dim_2085 = tensor.dim %5946, %c0 : tensor + %5947 = arith.index_cast %dim_2085 : index to i64 + %from_elements_2086 = tensor.from_elements %c1_i64, %5947, %c4096_i64 : tensor<3xi64> + %5948 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_2086, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2087 = tensor.dim %5948, %c1 : tensor<1x?x4096xi64> + %5949 = arith.index_cast %dim_2087 : index 
to i64 + %from_elements_2088 = tensor.from_elements %c1_i64, %5949, %c4096_i64, %c1_i64 : tensor<4xi64> + %5950 = stablehlo.dynamic_reshape %5948, %from_elements_2088 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %5951 = stablehlo.dynamic_broadcast_in_dim %5946, %from_elements_2086, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2089 = tensor.dim %5951, %c1 : tensor<1x?x4096xi64> + %5952 = arith.index_cast %dim_2089 : index to i64 + %from_elements_2090 = tensor.from_elements %c1_i64, %5952, %c4096_i64, %c1_i64 : tensor<4xi64> + %5953 = stablehlo.dynamic_reshape %5951, %from_elements_2090 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %5954 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_2086, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2091 = tensor.dim %5954, %c1 : tensor<1x?x4096xi64> + %5955 = arith.index_cast %dim_2091 : index to i64 + %from_elements_2092 = tensor.from_elements %c1_i64, %5955, %c4096_i64, %c1_i64 : tensor<4xi64> + %5956 = stablehlo.dynamic_reshape %5954, %from_elements_2092 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %5957 = stablehlo.concatenate %5950, %5953, %5956, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %5958 = "stablehlo.gather"(%5691, %5957) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %5959 = shape.shape_of %5958 : tensor<1x?x4096xf32> -> tensor<3xindex> + %5960 = shape.num_elements %5959 : tensor<3xindex> -> index + %5961 = stablehlo.compute_reshape_shape %5960, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %5962 = stablehlo.dynamic_reshape %5958, %5961 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %5963 = stablehlo.dot %5962, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %5964 = stablehlo.logistic %5963 : tensor + %5965 = shape.shape_of %5964 : tensor -> tensor<2xindex> + %5966 = shape.shape_of %5963 : tensor -> tensor<2xindex> + %5967 = shape.cstr_broadcastable %5965, %5966 : tensor<2xindex>, tensor<2xindex> + %5968 = shape.assuming %5967 -> (tensor) { + %19688 = shape.broadcast %5965, %5966 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %5964, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %5963, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %5969 = shape.shape_of %5968 : tensor -> tensor<2xindex> + %5970 = shape.cstr_broadcastable %5969, %5966 : tensor<2xindex>, tensor<2xindex> + %5971 = shape.assuming %5970 -> (tensor) { + %19688 = shape.broadcast %5969, %5966 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %5968, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %5963, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %5972 = stablehlo.dot %5971, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_2093 = tensor.dim %5944, %c0 : tensor + %5973 = arith.index_cast %dim_2093 : index to i64 + %from_elements_2094 = tensor.from_elements %5973, %c1_i64 : tensor<2xi64> + %5974 = 
stablehlo.dynamic_reshape %5944, %from_elements_2094 : (tensor, tensor<2xi64>) -> tensor + %dim_2095 = tensor.dim %5941, %c0 : tensor + %5975 = arith.index_cast %dim_2095 : index to i64 + %from_elements_2096 = tensor.from_elements %5975, %c1_i64 : tensor<2xi64> + %5976 = stablehlo.dynamic_reshape %5941, %from_elements_2096 : (tensor, tensor<2xi64>) -> tensor + %5977 = stablehlo.concatenate %5974, %5976, dim = 1 : (tensor, tensor) -> tensor + %5978 = "stablehlo.gather"(%5720, %5977) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %5979 = shape.shape_of %5972 : tensor -> tensor<2xindex> + %5980 = shape.shape_of %5978 : tensor -> tensor<2xindex> + %5981 = shape.cstr_broadcastable %5979, %5980 : tensor<2xindex>, tensor<2xindex> + %5982 = shape.assuming %5981 -> (tensor) { + %19688 = shape.broadcast %5979, %5980 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %5972, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %5978, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %5983 = shape.shape_of %5982 : tensor -> tensor<2xindex> + %5984 = stablehlo.dynamic_broadcast_in_dim %5982, %5983, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %5985 = stablehlo.dynamic_broadcast_in_dim %213, %5983, dims = [] : (tensor, tensor<2xindex>) -> tensor + %5986 = stablehlo.multiply %5984, %5985 : tensor + %dim_2097 = tensor.dim %5946, %c0 : tensor + %5987 = arith.index_cast %dim_2097 : index to i64 + %dim_2098 = tensor.dim %5982, %c0 : tensor + %5988 = arith.index_cast %dim_2098 : index to i64 + %5989 = arith.maxsi %5987, %5988 : i64 + %5990 = arith.index_cast %5989 : i64 to index + %from_elements_2099 = tensor.from_elements %5990, %c4096 : tensor<2xindex> + %5991 = stablehlo.dynamic_broadcast_in_dim %5946, %from_elements_2099, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_2100 = tensor.dim %5991, %c0 : tensor + %5992 = arith.index_cast %dim_2100 : index to i64 + %from_elements_2101 = tensor.from_elements %5992, %c4096_i64 : tensor<2xi64> + %5993 = stablehlo.real_dynamic_slice %5986, %c_22, %from_elements_2101, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_2102 = tensor.from_elements %5992, %c4096_i64, %c1_i64 : tensor<3xi64> + %5994 = stablehlo.dynamic_reshape %5991, %from_elements_2102 : (tensor, tensor<3xi64>) -> tensor + %5995 = stablehlo.dynamic_iota %from_elements_2102, dim = 1 : (tensor<3xi64>) -> tensor + %5996 = stablehlo.concatenate %5994, %5995, dim = 2 : (tensor, tensor) -> tensor + %5997 = "stablehlo.scatter"(%5934, %5996, %5993) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %5998 = stablehlo.slice %5680 [5:6, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %5999 = stablehlo.reshape %5998 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %6000 = stablehlo.custom_call @byteir.non_zero(%5999) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_2103 = tensor.dim %6000, %c0 : tensor + %6001 = arith.index_cast %dim_2103 : index to i64 + %from_elements_2104 = tensor.from_elements %6001, %c1_i64 
: tensor<2xi64> + %6002 = stablehlo.real_dynamic_slice %6000, %c_22, %from_elements_2104, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2105 = tensor.dim %6002, %c0 : tensor + %6003 = arith.index_cast %dim_2105 : index to i64 + %from_elements_2106 = tensor.from_elements %6003 : tensor<1xi64> + %6004 = stablehlo.dynamic_reshape %6002, %from_elements_2106 : (tensor, tensor<1xi64>) -> tensor + %from_elements_2107 = tensor.from_elements %6001, %c2_i64 : tensor<2xi64> + %6005 = stablehlo.real_dynamic_slice %6000, %c_24, %from_elements_2107, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2108 = tensor.dim %6005, %c0 : tensor + %6006 = arith.index_cast %dim_2108 : index to i64 + %from_elements_2109 = tensor.from_elements %6006 : tensor<1xi64> + %6007 = stablehlo.dynamic_reshape %6005, %from_elements_2109 : (tensor, tensor<1xi64>) -> tensor + %dim_2110 = tensor.dim %6007, %c0 : tensor + %6008 = arith.index_cast %dim_2110 : index to i64 + %from_elements_2111 = tensor.from_elements %6008, %c1_i64 : tensor<2xi64> + %6009 = stablehlo.dynamic_reshape %6007, %from_elements_2111 : (tensor, tensor<2xi64>) -> tensor + %dim_2112 = tensor.dim %6009, %c0 : tensor + %6010 = arith.index_cast %dim_2112 : index to i64 + %from_elements_2113 = tensor.from_elements %c1_i64, %6010, %c4096_i64 : tensor<3xi64> + %6011 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_2113, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2114 = tensor.dim %6011, %c1 : tensor<1x?x4096xi64> + %6012 = arith.index_cast %dim_2114 : index to i64 + %from_elements_2115 = tensor.from_elements %c1_i64, %6012, %c4096_i64, %c1_i64 : tensor<4xi64> + %6013 = stablehlo.dynamic_reshape %6011, %from_elements_2115 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %6014 = stablehlo.dynamic_broadcast_in_dim %6009, %from_elements_2113, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2116 = tensor.dim %6014, %c1 : tensor<1x?x4096xi64> + %6015 = arith.index_cast %dim_2116 : index to i64 + %from_elements_2117 = tensor.from_elements %c1_i64, %6015, %c4096_i64, %c1_i64 : tensor<4xi64> + %6016 = stablehlo.dynamic_reshape %6014, %from_elements_2117 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %6017 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_2113, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2118 = tensor.dim %6017, %c1 : tensor<1x?x4096xi64> + %6018 = arith.index_cast %dim_2118 : index to i64 + %from_elements_2119 = tensor.from_elements %c1_i64, %6018, %c4096_i64, %c1_i64 : tensor<4xi64> + %6019 = stablehlo.dynamic_reshape %6017, %from_elements_2119 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %6020 = stablehlo.concatenate %6013, %6016, %6019, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %6021 = "stablehlo.gather"(%5691, %6020) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %6022 = shape.shape_of %6021 : tensor<1x?x4096xf32> -> tensor<3xindex> + %6023 = shape.num_elements %6022 : tensor<3xindex> -> index + %6024 = stablehlo.compute_reshape_shape %6023, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %6025 = stablehlo.dynamic_reshape %6021, %6024 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %6026 = stablehlo.dot %6025, %190 : (tensor, 
tensor<4096x14336xf32>) -> tensor + %6027 = stablehlo.logistic %6026 : tensor + %6028 = shape.shape_of %6027 : tensor -> tensor<2xindex> + %6029 = shape.shape_of %6026 : tensor -> tensor<2xindex> + %6030 = shape.cstr_broadcastable %6028, %6029 : tensor<2xindex>, tensor<2xindex> + %6031 = shape.assuming %6030 -> (tensor) { + %19688 = shape.broadcast %6028, %6029 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %6027, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %6026, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %6032 = shape.shape_of %6031 : tensor -> tensor<2xindex> + %6033 = shape.cstr_broadcastable %6032, %6029 : tensor<2xindex>, tensor<2xindex> + %6034 = shape.assuming %6033 -> (tensor) { + %19688 = shape.broadcast %6032, %6029 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %6031, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %6026, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %6035 = stablehlo.dot %6034, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_2120 = tensor.dim %6007, %c0 : tensor + %6036 = arith.index_cast %dim_2120 : index to i64 + %from_elements_2121 = tensor.from_elements %6036, %c1_i64 : tensor<2xi64> + %6037 = stablehlo.dynamic_reshape %6007, %from_elements_2121 : (tensor, tensor<2xi64>) -> tensor + %dim_2122 = tensor.dim %6004, %c0 : tensor + %6038 = arith.index_cast %dim_2122 : index to i64 + %from_elements_2123 = tensor.from_elements %6038, %c1_i64 : tensor<2xi64> + %6039 = stablehlo.dynamic_reshape %6004, %from_elements_2123 : (tensor, tensor<2xi64>) -> tensor + %6040 = stablehlo.concatenate %6037, %6039, dim = 1 : (tensor, tensor) -> tensor + %6041 = "stablehlo.gather"(%5720, %6040) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %6042 = shape.shape_of %6035 : tensor -> tensor<2xindex> + %6043 = shape.shape_of %6041 : tensor -> tensor<2xindex> + %6044 = shape.cstr_broadcastable %6042, %6043 : tensor<2xindex>, tensor<2xindex> + %6045 = shape.assuming %6044 -> (tensor) { + %19688 = shape.broadcast %6042, %6043 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %6035, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %6041, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %6046 = shape.shape_of %6045 : tensor -> tensor<2xindex> + %6047 = stablehlo.dynamic_broadcast_in_dim %6045, %6046, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %6048 = stablehlo.dynamic_broadcast_in_dim %213, %6046, dims = [] : (tensor, tensor<2xindex>) -> tensor + %6049 = stablehlo.multiply %6047, %6048 : tensor + %dim_2124 = tensor.dim %6009, %c0 : tensor + %6050 = arith.index_cast %dim_2124 : index to i64 + %dim_2125 = tensor.dim %6045, %c0 : tensor + %6051 = arith.index_cast %dim_2125 : index to i64 + %6052 = arith.maxsi %6050, %6051 : i64 + %6053 = arith.index_cast %6052 : i64 to index + %from_elements_2126 = tensor.from_elements %6053, 
%c4096 : tensor<2xindex> + %6054 = stablehlo.dynamic_broadcast_in_dim %6009, %from_elements_2126, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_2127 = tensor.dim %6054, %c0 : tensor + %6055 = arith.index_cast %dim_2127 : index to i64 + %from_elements_2128 = tensor.from_elements %6055, %c4096_i64 : tensor<2xi64> + %6056 = stablehlo.real_dynamic_slice %6049, %c_22, %from_elements_2128, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_2129 = tensor.from_elements %6055, %c4096_i64, %c1_i64 : tensor<3xi64> + %6057 = stablehlo.dynamic_reshape %6054, %from_elements_2129 : (tensor, tensor<3xi64>) -> tensor + %6058 = stablehlo.dynamic_iota %from_elements_2129, dim = 1 : (tensor<3xi64>) -> tensor + %6059 = stablehlo.concatenate %6057, %6058, dim = 2 : (tensor, tensor) -> tensor + %6060 = "stablehlo.scatter"(%5997, %6059, %6056) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %6061 = stablehlo.slice %5680 [6:7, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %6062 = stablehlo.reshape %6061 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %6063 = stablehlo.custom_call @byteir.non_zero(%6062) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_2130 = tensor.dim %6063, %c0 : tensor + %6064 = arith.index_cast %dim_2130 : index to i64 + %from_elements_2131 = tensor.from_elements %6064, %c1_i64 : tensor<2xi64> + %6065 = stablehlo.real_dynamic_slice %6063, %c_22, %from_elements_2131, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2132 = tensor.dim %6065, %c0 : tensor + %6066 = arith.index_cast %dim_2132 : index to i64 + %from_elements_2133 = tensor.from_elements %6066 : tensor<1xi64> + %6067 = stablehlo.dynamic_reshape %6065, %from_elements_2133 : (tensor, tensor<1xi64>) -> tensor + %from_elements_2134 = tensor.from_elements %6064, %c2_i64 : tensor<2xi64> + %6068 = stablehlo.real_dynamic_slice %6063, %c_24, %from_elements_2134, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2135 = tensor.dim %6068, %c0 : tensor + %6069 = arith.index_cast %dim_2135 : index to i64 + %from_elements_2136 = tensor.from_elements %6069 : tensor<1xi64> + %6070 = stablehlo.dynamic_reshape %6068, %from_elements_2136 : (tensor, tensor<1xi64>) -> tensor + %dim_2137 = tensor.dim %6070, %c0 : tensor + %6071 = arith.index_cast %dim_2137 : index to i64 + %from_elements_2138 = tensor.from_elements %6071, %c1_i64 : tensor<2xi64> + %6072 = stablehlo.dynamic_reshape %6070, %from_elements_2138 : (tensor, tensor<2xi64>) -> tensor + %dim_2139 = tensor.dim %6072, %c0 : tensor + %6073 = arith.index_cast %dim_2139 : index to i64 + %from_elements_2140 = tensor.from_elements %c1_i64, %6073, %c4096_i64 : tensor<3xi64> + %6074 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_2140, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2141 = tensor.dim %6074, %c1 : tensor<1x?x4096xi64> + %6075 = arith.index_cast %dim_2141 : index to i64 + %from_elements_2142 = tensor.from_elements %c1_i64, %6075, %c4096_i64, %c1_i64 : tensor<4xi64> + %6076 = stablehlo.dynamic_reshape %6074, %from_elements_2142 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %6077 = stablehlo.dynamic_broadcast_in_dim %6072, %from_elements_2140, dims = [1, 2] : 
(tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2143 = tensor.dim %6077, %c1 : tensor<1x?x4096xi64> + %6078 = arith.index_cast %dim_2143 : index to i64 + %from_elements_2144 = tensor.from_elements %c1_i64, %6078, %c4096_i64, %c1_i64 : tensor<4xi64> + %6079 = stablehlo.dynamic_reshape %6077, %from_elements_2144 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %6080 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_2140, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2145 = tensor.dim %6080, %c1 : tensor<1x?x4096xi64> + %6081 = arith.index_cast %dim_2145 : index to i64 + %from_elements_2146 = tensor.from_elements %c1_i64, %6081, %c4096_i64, %c1_i64 : tensor<4xi64> + %6082 = stablehlo.dynamic_reshape %6080, %from_elements_2146 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %6083 = stablehlo.concatenate %6076, %6079, %6082, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %6084 = "stablehlo.gather"(%5691, %6083) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %6085 = shape.shape_of %6084 : tensor<1x?x4096xf32> -> tensor<3xindex> + %6086 = shape.num_elements %6085 : tensor<3xindex> -> index + %6087 = stablehlo.compute_reshape_shape %6086, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %6088 = stablehlo.dynamic_reshape %6084, %6087 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %6089 = stablehlo.dot %6088, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %6090 = stablehlo.logistic %6089 : tensor + %6091 = shape.shape_of %6090 : tensor -> tensor<2xindex> + %6092 = shape.shape_of %6089 : tensor -> tensor<2xindex> + %6093 = shape.cstr_broadcastable %6091, %6092 : tensor<2xindex>, tensor<2xindex> + %6094 = shape.assuming %6093 -> (tensor) { + %19688 = shape.broadcast %6091, %6092 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %6090, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %6089, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %6095 = shape.shape_of %6094 : tensor -> tensor<2xindex> + %6096 = shape.cstr_broadcastable %6095, %6092 : tensor<2xindex>, tensor<2xindex> + %6097 = shape.assuming %6096 -> (tensor) { + %19688 = shape.broadcast %6095, %6092 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %6094, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %6089, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %6098 = stablehlo.dot %6097, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_2147 = tensor.dim %6070, %c0 : tensor + %6099 = arith.index_cast %dim_2147 : index to i64 + %from_elements_2148 = tensor.from_elements %6099, %c1_i64 : tensor<2xi64> + %6100 = stablehlo.dynamic_reshape %6070, %from_elements_2148 : (tensor, tensor<2xi64>) -> tensor + %dim_2149 = tensor.dim %6067, %c0 : tensor + %6101 = arith.index_cast %dim_2149 : index to i64 + %from_elements_2150 = tensor.from_elements %6101, %c1_i64 : tensor<2xi64> + %6102 = stablehlo.dynamic_reshape %6067, %from_elements_2150 : 
(tensor, tensor<2xi64>) -> tensor + %6103 = stablehlo.concatenate %6100, %6102, dim = 1 : (tensor, tensor) -> tensor + %6104 = "stablehlo.gather"(%5720, %6103) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %6105 = shape.shape_of %6098 : tensor -> tensor<2xindex> + %6106 = shape.shape_of %6104 : tensor -> tensor<2xindex> + %6107 = shape.cstr_broadcastable %6105, %6106 : tensor<2xindex>, tensor<2xindex> + %6108 = shape.assuming %6107 -> (tensor) { + %19688 = shape.broadcast %6105, %6106 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %6098, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %6104, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %6109 = shape.shape_of %6108 : tensor -> tensor<2xindex> + %6110 = stablehlo.dynamic_broadcast_in_dim %6108, %6109, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %6111 = stablehlo.dynamic_broadcast_in_dim %213, %6109, dims = [] : (tensor, tensor<2xindex>) -> tensor + %6112 = stablehlo.multiply %6110, %6111 : tensor + %dim_2151 = tensor.dim %6072, %c0 : tensor + %6113 = arith.index_cast %dim_2151 : index to i64 + %dim_2152 = tensor.dim %6108, %c0 : tensor + %6114 = arith.index_cast %dim_2152 : index to i64 + %6115 = arith.maxsi %6113, %6114 : i64 + %6116 = arith.index_cast %6115 : i64 to index + %from_elements_2153 = tensor.from_elements %6116, %c4096 : tensor<2xindex> + %6117 = stablehlo.dynamic_broadcast_in_dim %6072, %from_elements_2153, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_2154 = tensor.dim %6117, %c0 : tensor + %6118 = arith.index_cast %dim_2154 : index to i64 + %from_elements_2155 = tensor.from_elements %6118, %c4096_i64 : tensor<2xi64> + %6119 = stablehlo.real_dynamic_slice %6112, %c_22, %from_elements_2155, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_2156 = tensor.from_elements %6118, %c4096_i64, %c1_i64 : tensor<3xi64> + %6120 = stablehlo.dynamic_reshape %6117, %from_elements_2156 : (tensor, tensor<3xi64>) -> tensor + %6121 = stablehlo.dynamic_iota %from_elements_2156, dim = 1 : (tensor<3xi64>) -> tensor + %6122 = stablehlo.concatenate %6120, %6121, dim = 2 : (tensor, tensor) -> tensor + %6123 = "stablehlo.scatter"(%6060, %6122, %6119) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %6124 = stablehlo.slice %5680 [7:8, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %6125 = stablehlo.reshape %6124 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %6126 = stablehlo.custom_call @byteir.non_zero(%6125) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_2157 = tensor.dim %6126, %c0 : tensor + %6127 = arith.index_cast %dim_2157 : index to i64 + %from_elements_2158 = tensor.from_elements %6127, %c1_i64 : tensor<2xi64> + %6128 = stablehlo.real_dynamic_slice %6126, %c_22, %from_elements_2158, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2159 = tensor.dim %6128, %c0 : tensor + %6129 = arith.index_cast %dim_2159 : index to i64 + %from_elements_2160 = tensor.from_elements %6129 : tensor<1xi64> 
+ %6130 = stablehlo.dynamic_reshape %6128, %from_elements_2160 : (tensor, tensor<1xi64>) -> tensor + %from_elements_2161 = tensor.from_elements %6127, %c2_i64 : tensor<2xi64> + %6131 = stablehlo.real_dynamic_slice %6126, %c_24, %from_elements_2161, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2162 = tensor.dim %6131, %c0 : tensor + %6132 = arith.index_cast %dim_2162 : index to i64 + %from_elements_2163 = tensor.from_elements %6132 : tensor<1xi64> + %6133 = stablehlo.dynamic_reshape %6131, %from_elements_2163 : (tensor, tensor<1xi64>) -> tensor + %dim_2164 = tensor.dim %6133, %c0 : tensor + %6134 = arith.index_cast %dim_2164 : index to i64 + %from_elements_2165 = tensor.from_elements %6134, %c1_i64 : tensor<2xi64> + %6135 = stablehlo.dynamic_reshape %6133, %from_elements_2165 : (tensor, tensor<2xi64>) -> tensor + %dim_2166 = tensor.dim %6135, %c0 : tensor + %6136 = arith.index_cast %dim_2166 : index to i64 + %from_elements_2167 = tensor.from_elements %c1_i64, %6136, %c4096_i64 : tensor<3xi64> + %6137 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_2167, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2168 = tensor.dim %6137, %c1 : tensor<1x?x4096xi64> + %6138 = arith.index_cast %dim_2168 : index to i64 + %from_elements_2169 = tensor.from_elements %c1_i64, %6138, %c4096_i64, %c1_i64 : tensor<4xi64> + %6139 = stablehlo.dynamic_reshape %6137, %from_elements_2169 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %6140 = stablehlo.dynamic_broadcast_in_dim %6135, %from_elements_2167, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2170 = tensor.dim %6140, %c1 : tensor<1x?x4096xi64> + %6141 = arith.index_cast %dim_2170 : index to i64 + %from_elements_2171 = tensor.from_elements %c1_i64, %6141, %c4096_i64, %c1_i64 : tensor<4xi64> + %6142 = stablehlo.dynamic_reshape %6140, %from_elements_2171 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %6143 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_2167, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2172 = tensor.dim %6143, %c1 : tensor<1x?x4096xi64> + %6144 = arith.index_cast %dim_2172 : index to i64 + %from_elements_2173 = tensor.from_elements %c1_i64, %6144, %c4096_i64, %c1_i64 : tensor<4xi64> + %6145 = stablehlo.dynamic_reshape %6143, %from_elements_2173 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %6146 = stablehlo.concatenate %6139, %6142, %6145, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %6147 = "stablehlo.gather"(%5691, %6146) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %6148 = shape.shape_of %6147 : tensor<1x?x4096xf32> -> tensor<3xindex> + %6149 = shape.num_elements %6148 : tensor<3xindex> -> index + %6150 = stablehlo.compute_reshape_shape %6149, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %6151 = stablehlo.dynamic_reshape %6147, %6150 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %6152 = stablehlo.dot %6151, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %6153 = stablehlo.logistic %6152 : tensor + %6154 = shape.shape_of %6153 : tensor -> tensor<2xindex> + %6155 = shape.shape_of %6152 : tensor -> tensor<2xindex> + %6156 = shape.cstr_broadcastable %6154, %6155 : tensor<2xindex>, tensor<2xindex> + %6157 = shape.assuming %6156 -> (tensor) { + 
%19688 = shape.broadcast %6154, %6155 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %6153, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %6152, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %6158 = shape.shape_of %6157 : tensor -> tensor<2xindex> + %6159 = shape.cstr_broadcastable %6158, %6155 : tensor<2xindex>, tensor<2xindex> + %6160 = shape.assuming %6159 -> (tensor) { + %19688 = shape.broadcast %6158, %6155 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %6157, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %6152, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %6161 = stablehlo.dot %6160, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_2174 = tensor.dim %6133, %c0 : tensor + %6162 = arith.index_cast %dim_2174 : index to i64 + %from_elements_2175 = tensor.from_elements %6162, %c1_i64 : tensor<2xi64> + %6163 = stablehlo.dynamic_reshape %6133, %from_elements_2175 : (tensor, tensor<2xi64>) -> tensor + %dim_2176 = tensor.dim %6130, %c0 : tensor + %6164 = arith.index_cast %dim_2176 : index to i64 + %from_elements_2177 = tensor.from_elements %6164, %c1_i64 : tensor<2xi64> + %6165 = stablehlo.dynamic_reshape %6130, %from_elements_2177 : (tensor, tensor<2xi64>) -> tensor + %6166 = stablehlo.concatenate %6163, %6165, dim = 1 : (tensor, tensor) -> tensor + %6167 = "stablehlo.gather"(%5720, %6166) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %6168 = shape.shape_of %6161 : tensor -> tensor<2xindex> + %6169 = shape.shape_of %6167 : tensor -> tensor<2xindex> + %6170 = shape.cstr_broadcastable %6168, %6169 : tensor<2xindex>, tensor<2xindex> + %6171 = shape.assuming %6170 -> (tensor) { + %19688 = shape.broadcast %6168, %6169 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %6161, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %6167, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %6172 = shape.shape_of %6171 : tensor -> tensor<2xindex> + %6173 = stablehlo.dynamic_broadcast_in_dim %6171, %6172, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %6174 = stablehlo.dynamic_broadcast_in_dim %213, %6172, dims = [] : (tensor, tensor<2xindex>) -> tensor + %6175 = stablehlo.multiply %6173, %6174 : tensor + %dim_2178 = tensor.dim %6135, %c0 : tensor + %6176 = arith.index_cast %dim_2178 : index to i64 + %dim_2179 = tensor.dim %6171, %c0 : tensor + %6177 = arith.index_cast %dim_2179 : index to i64 + %6178 = arith.maxsi %6176, %6177 : i64 + %6179 = arith.index_cast %6178 : i64 to index + %from_elements_2180 = tensor.from_elements %6179, %c4096 : tensor<2xindex> + %6180 = stablehlo.dynamic_broadcast_in_dim %6135, %from_elements_2180, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_2181 = tensor.dim %6180, %c0 : tensor + %6181 = arith.index_cast %dim_2181 : index to i64 + %from_elements_2182 = tensor.from_elements %6181, %c4096_i64 : tensor<2xi64> + 
%6182 = stablehlo.real_dynamic_slice %6175, %c_22, %from_elements_2182, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_2183 = tensor.from_elements %6181, %c4096_i64, %c1_i64 : tensor<3xi64> + %6183 = stablehlo.dynamic_reshape %6180, %from_elements_2183 : (tensor, tensor<3xi64>) -> tensor + %6184 = stablehlo.dynamic_iota %from_elements_2183, dim = 1 : (tensor<3xi64>) -> tensor + %6185 = stablehlo.concatenate %6183, %6184, dim = 2 : (tensor, tensor) -> tensor + %6186 = "stablehlo.scatter"(%6123, %6185, %6182) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %6187 = stablehlo.reshape %6186 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %6188 = stablehlo.add %5653, %6187 : tensor<3x1x4096xf32> + %6189 = stablehlo.broadcast_in_dim %6188, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %6190 = stablehlo.power %6189, %15 : tensor<3x1x4096xf32> + %6191 = stablehlo.reduce(%6190 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %6192 = stablehlo.reshape %6191 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %6193 = stablehlo.broadcast_in_dim %6192, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %6194 = stablehlo.divide %6193, %21 : tensor<3x1x1xf32> + %6195 = stablehlo.broadcast_in_dim %6194, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %6196 = stablehlo.add %6195, %25 : tensor<3x1x1xf32> + %6197 = stablehlo.rsqrt %6196 : tensor<3x1x1xf32> + %6198 = stablehlo.broadcast_in_dim %6197, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %6199 = stablehlo.multiply %6189, %6198 : tensor<3x1x4096xf32> + %6200 = stablehlo.broadcast_in_dim %6199, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %6201 = stablehlo.multiply %6200, %31 : tensor<3x1x4096xf32> + %6202 = stablehlo.reshape %6201 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %6203 = stablehlo.dot %6202, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %6204 = stablehlo.reshape %6203 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %6205 = stablehlo.dot %6202, %37 : (tensor<3x4096xf32>, tensor<4096x1024xf32>) -> tensor<3x1024xf32> + %6206 = stablehlo.reshape %6205 : (tensor<3x1024xf32>) -> tensor<3x1x1024xf32> + %6207 = stablehlo.reshape %6204 : (tensor<3x1x4096xf32>) -> tensor<3x1x32x128xf32> + %6208 = stablehlo.transpose %6207, dims = [0, 2, 1, 3] : (tensor<3x1x32x128xf32>) -> tensor<3x32x1x128xf32> + %6209 = stablehlo.reshape %6206 : (tensor<3x1x1024xf32>) -> tensor<3x1x8x128xf32> + %6210 = stablehlo.transpose %6209, dims = [0, 2, 1, 3] : (tensor<3x1x8x128xf32>) -> tensor<3x8x1x128xf32> + %6211 = stablehlo.slice %arg20 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %6212 = stablehlo.slice %arg21 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %6213 = "stablehlo.gather"(%6211, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %6214 = stablehlo.reshape %6213 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %6215 = "stablehlo.gather"(%6212, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : 
(tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %6216 = stablehlo.reshape %6215 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %6217 = stablehlo.broadcast_in_dim %6208, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %6218 = stablehlo.broadcast_in_dim %6214, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %6219 = stablehlo.multiply %6217, %6218 : tensor<3x32x1x128xf32> + %6220 = stablehlo.slice %6208 [0:3, 0:32, 0:1, 0:64] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %6221 = stablehlo.slice %6208 [0:3, 0:32, 0:1, 64:128] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %6222 = stablehlo.negate %6221 : tensor<3x32x1x64xf32> + %6223 = stablehlo.concatenate %6222, %6220, dim = 3 : (tensor<3x32x1x64xf32>, tensor<3x32x1x64xf32>) -> tensor<3x32x1x128xf32> + %6224 = stablehlo.broadcast_in_dim %6223, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %6225 = stablehlo.broadcast_in_dim %6216, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %6226 = stablehlo.multiply %6224, %6225 : tensor<3x32x1x128xf32> + %6227 = stablehlo.add %6219, %6226 : tensor<3x32x1x128xf32> + %6228 = stablehlo.broadcast_in_dim %6210, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %6229 = stablehlo.broadcast_in_dim %6214, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %6230 = stablehlo.multiply %6228, %6229 : tensor<3x8x1x128xf32> + %6231 = stablehlo.slice %6210 [0:3, 0:8, 0:1, 0:64] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %6232 = stablehlo.slice %6210 [0:3, 0:8, 0:1, 64:128] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %6233 = stablehlo.negate %6232 : tensor<3x8x1x64xf32> + %6234 = stablehlo.concatenate %6233, %6231, dim = 3 : (tensor<3x8x1x64xf32>, tensor<3x8x1x64xf32>) -> tensor<3x8x1x128xf32> + %6235 = stablehlo.broadcast_in_dim %6234, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %6236 = stablehlo.broadcast_in_dim %6216, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %6237 = stablehlo.multiply %6235, %6236 : tensor<3x8x1x128xf32> + %6238 = stablehlo.add %6230, %6237 : tensor<3x8x1x128xf32> + %6239 = stablehlo.concatenate %arg85, %6238, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %6240 = stablehlo.concatenate %arg86, %6210, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %6241 = stablehlo.reshape %6239 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %6242 = stablehlo.broadcast_in_dim %6241, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %6243 = stablehlo.reshape %6242 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %6244 = stablehlo.reshape %6240 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %6245 = stablehlo.broadcast_in_dim %6244, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %6246 = stablehlo.reshape %6245 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %6247 = stablehlo.transpose %6243, dims = [0, 1, 3, 2] : (tensor<3x32x8x128xf32>) -> tensor<3x32x128x8xf32> + %6248 = stablehlo.reshape %6227 : (tensor<3x32x1x128xf32>) -> tensor<96x1x128xf32> + %6249 = stablehlo.reshape %6247 : (tensor<3x32x128x8xf32>) -> tensor<96x128x8xf32> + %6250 = stablehlo.broadcast_in_dim %6249, dims = [0, 1, 2] : (tensor<96x128x8xf32>) -> tensor<96x128x8xf32> + %6251 = stablehlo.dot_general %6248, %6250, 
batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x128xf32>, tensor<96x128x8xf32>) -> tensor<96x1x8xf32> + %6252 = stablehlo.reshape %6251 : (tensor<96x1x8xf32>) -> tensor<3x32x1x8xf32> + %6253 = stablehlo.broadcast_in_dim %6252, dims = [0, 1, 2, 3] : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %6254 = stablehlo.divide %6253, %89 : tensor<3x32x1x8xf32> + %6255 = stablehlo.custom_call @byteir.softmax(%6254) {byteir_attrs = {axis = 3 : i64}} : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %6256 = stablehlo.reshape %6255 : (tensor<3x32x1x8xf32>) -> tensor<96x1x8xf32> + %6257 = stablehlo.reshape %6246 : (tensor<3x32x8x128xf32>) -> tensor<96x8x128xf32> + %6258 = stablehlo.broadcast_in_dim %6257, dims = [0, 1, 2] : (tensor<96x8x128xf32>) -> tensor<96x8x128xf32> + %6259 = stablehlo.dot_general %6256, %6258, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x8xf32>, tensor<96x8x128xf32>) -> tensor<96x1x128xf32> + %6260 = stablehlo.reshape %6259 : (tensor<96x1x128xf32>) -> tensor<3x32x1x128xf32> + %6261 = stablehlo.transpose %6260, dims = [0, 2, 1, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x1x32x128xf32> + %6262 = stablehlo.reshape %6261 : (tensor<3x1x32x128xf32>) -> tensor<3x1x4096xf32> + %6263 = stablehlo.reshape %6262 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %6264 = stablehlo.dot %6263, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %6265 = stablehlo.reshape %6264 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %6266 = stablehlo.add %6188, %6265 : tensor<3x1x4096xf32> + %6267 = stablehlo.broadcast_in_dim %6266, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %6268 = stablehlo.power %6267, %15 : tensor<3x1x4096xf32> + %6269 = stablehlo.reduce(%6268 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %6270 = stablehlo.reshape %6269 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %6271 = stablehlo.broadcast_in_dim %6270, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %6272 = stablehlo.divide %6271, %21 : tensor<3x1x1xf32> + %6273 = stablehlo.broadcast_in_dim %6272, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %6274 = stablehlo.add %6273, %25 : tensor<3x1x1xf32> + %6275 = stablehlo.rsqrt %6274 : tensor<3x1x1xf32> + %6276 = stablehlo.broadcast_in_dim %6275, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %6277 = stablehlo.multiply %6267, %6276 : tensor<3x1x4096xf32> + %6278 = stablehlo.broadcast_in_dim %6277, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %6279 = stablehlo.multiply %6278, %31 : tensor<3x1x4096xf32> + %6280 = stablehlo.reshape %6279 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %6281 = stablehlo.dot %6280, %117 : (tensor<3x4096xf32>, tensor<4096x8xf32>) -> tensor<3x8xf32> + %6282 = stablehlo.custom_call @byteir.softmax(%6281) {byteir_attrs = {axis = 1 : i64}} : (tensor<3x8xf32>) -> tensor<3x8xf32> + %6283:2 = stablehlo.custom_call @byteir.top_k(%6282) {byteir_attrs = {axis = [1], k = 2 : i64, sorted = true}} : (tensor<3x8xf32>) -> (tensor<3x2xf32>, tensor<3x2xi64>) + %6284 = stablehlo.reduce(%6283#0 init: %cst_20) applies stablehlo.add across dimensions = [1] : (tensor<3x2xf32>, tensor) -> tensor<3xf32> + %6285 = stablehlo.reshape %6284 : (tensor<3xf32>) -> tensor<3x1xf32> + %6286 = stablehlo.broadcast_in_dim %6283#0, dims = [0, 1] : (tensor<3x2xf32>) -> tensor<3x2xf32> + %6287 = stablehlo.broadcast_in_dim %6285, dims = [0, 1] : (tensor<3x1xf32>) -> 
tensor<3x2xf32> + %6288 = stablehlo.divide %6286, %6287 : tensor<3x2xf32> + %6289 = stablehlo.reshape %6283#1 : (tensor<3x2xi64>) -> tensor<3x2x1xi64> + %6290 = stablehlo.broadcast_in_dim %6289, dims = [0, 1, 2] : (tensor<3x2x1xi64>) -> tensor<3x2x8xi64> + %6291 = stablehlo.compare EQ, %6290, %137, SIGNED : (tensor<3x2x8xi64>, tensor<3x2x8xi64>) -> tensor<3x2x8xi1> + %6292 = stablehlo.convert %6291 : (tensor<3x2x8xi1>) -> tensor<3x2x8xi64> + %6293 = stablehlo.transpose %6292, dims = [2, 1, 0] : (tensor<3x2x8xi64>) -> tensor<8x2x3xi64> + %6294 = stablehlo.slice %6293 [0:1, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %6295 = stablehlo.reshape %6294 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %6296 = stablehlo.custom_call @byteir.non_zero(%6295) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_2184 = tensor.dim %6296, %c0 : tensor + %6297 = arith.index_cast %dim_2184 : index to i64 + %from_elements_2185 = tensor.from_elements %6297, %c1_i64 : tensor<2xi64> + %6298 = stablehlo.real_dynamic_slice %6296, %c_22, %from_elements_2185, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2186 = tensor.dim %6298, %c0 : tensor + %6299 = arith.index_cast %dim_2186 : index to i64 + %from_elements_2187 = tensor.from_elements %6299 : tensor<1xi64> + %6300 = stablehlo.dynamic_reshape %6298, %from_elements_2187 : (tensor, tensor<1xi64>) -> tensor + %from_elements_2188 = tensor.from_elements %6297, %c2_i64 : tensor<2xi64> + %6301 = stablehlo.real_dynamic_slice %6296, %c_24, %from_elements_2188, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2189 = tensor.dim %6301, %c0 : tensor + %6302 = arith.index_cast %dim_2189 : index to i64 + %from_elements_2190 = tensor.from_elements %6302 : tensor<1xi64> + %6303 = stablehlo.dynamic_reshape %6301, %from_elements_2190 : (tensor, tensor<1xi64>) -> tensor + %6304 = stablehlo.reshape %6280 : (tensor<3x4096xf32>) -> tensor<1x3x4096xf32> + %dim_2191 = tensor.dim %6303, %c0 : tensor + %6305 = arith.index_cast %dim_2191 : index to i64 + %from_elements_2192 = tensor.from_elements %6305, %c1_i64 : tensor<2xi64> + %6306 = stablehlo.dynamic_reshape %6303, %from_elements_2192 : (tensor, tensor<2xi64>) -> tensor + %dim_2193 = tensor.dim %6306, %c0 : tensor + %6307 = arith.index_cast %dim_2193 : index to i64 + %from_elements_2194 = tensor.from_elements %c1_i64, %6307, %c4096_i64 : tensor<3xi64> + %6308 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_2194, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2195 = tensor.dim %6308, %c1 : tensor<1x?x4096xi64> + %6309 = arith.index_cast %dim_2195 : index to i64 + %from_elements_2196 = tensor.from_elements %c1_i64, %6309, %c4096_i64, %c1_i64 : tensor<4xi64> + %6310 = stablehlo.dynamic_reshape %6308, %from_elements_2196 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %6311 = stablehlo.dynamic_broadcast_in_dim %6306, %from_elements_2194, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2197 = tensor.dim %6311, %c1 : tensor<1x?x4096xi64> + %6312 = arith.index_cast %dim_2197 : index to i64 + %from_elements_2198 = tensor.from_elements %c1_i64, %6312, %c4096_i64, %c1_i64 : tensor<4xi64> + %6313 = stablehlo.dynamic_reshape %6311, %from_elements_2198 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %6314 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_2194, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2199 = tensor.dim %6314, 
%c1 : tensor<1x?x4096xi64> + %6315 = arith.index_cast %dim_2199 : index to i64 + %from_elements_2200 = tensor.from_elements %c1_i64, %6315, %c4096_i64, %c1_i64 : tensor<4xi64> + %6316 = stablehlo.dynamic_reshape %6314, %from_elements_2200 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %6317 = stablehlo.concatenate %6310, %6313, %6316, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %6318 = "stablehlo.gather"(%6304, %6317) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %6319 = shape.shape_of %6318 : tensor<1x?x4096xf32> -> tensor<3xindex> + %6320 = shape.num_elements %6319 : tensor<3xindex> -> index + %6321 = stablehlo.compute_reshape_shape %6320, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %6322 = stablehlo.dynamic_reshape %6318, %6321 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %6323 = stablehlo.dot %6322, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %6324 = stablehlo.logistic %6323 : tensor + %6325 = shape.shape_of %6324 : tensor -> tensor<2xindex> + %6326 = shape.shape_of %6323 : tensor -> tensor<2xindex> + %6327 = shape.cstr_broadcastable %6325, %6326 : tensor<2xindex>, tensor<2xindex> + %6328 = shape.assuming %6327 -> (tensor) { + %19688 = shape.broadcast %6325, %6326 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %6324, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %6323, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %6329 = shape.shape_of %6328 : tensor -> tensor<2xindex> + %6330 = shape.cstr_broadcastable %6329, %6326 : tensor<2xindex>, tensor<2xindex> + %6331 = shape.assuming %6330 -> (tensor) { + %19688 = shape.broadcast %6329, %6326 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %6328, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %6323, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %6332 = stablehlo.dot %6331, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %6333 = stablehlo.reshape %6288 : (tensor<3x2xf32>) -> tensor<3x2x1xf32> + %dim_2201 = tensor.dim %6303, %c0 : tensor + %6334 = arith.index_cast %dim_2201 : index to i64 + %from_elements_2202 = tensor.from_elements %6334, %c1_i64 : tensor<2xi64> + %6335 = stablehlo.dynamic_reshape %6303, %from_elements_2202 : (tensor, tensor<2xi64>) -> tensor + %dim_2203 = tensor.dim %6300, %c0 : tensor + %6336 = arith.index_cast %dim_2203 : index to i64 + %from_elements_2204 = tensor.from_elements %6336, %c1_i64 : tensor<2xi64> + %6337 = stablehlo.dynamic_reshape %6300, %from_elements_2204 : (tensor, tensor<2xi64>) -> tensor + %6338 = stablehlo.concatenate %6335, %6337, dim = 1 : (tensor, tensor) -> tensor + %6339 = "stablehlo.gather"(%6333, %6338) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %6340 = shape.shape_of %6332 : tensor -> tensor<2xindex> + %6341 = shape.shape_of %6339 : tensor -> tensor<2xindex> + %6342 = shape.cstr_broadcastable %6340, %6341 : tensor<2xindex>, tensor<2xindex> 
+ %6343 = shape.assuming %6342 -> (tensor) { + %19688 = shape.broadcast %6340, %6341 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %6332, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %6339, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %6344 = shape.shape_of %6343 : tensor -> tensor<2xindex> + %6345 = stablehlo.dynamic_broadcast_in_dim %6343, %6344, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %6346 = stablehlo.dynamic_broadcast_in_dim %213, %6344, dims = [] : (tensor, tensor<2xindex>) -> tensor + %6347 = stablehlo.multiply %6345, %6346 : tensor + %dim_2205 = tensor.dim %6306, %c0 : tensor + %6348 = arith.index_cast %dim_2205 : index to i64 + %dim_2206 = tensor.dim %6343, %c0 : tensor + %6349 = arith.index_cast %dim_2206 : index to i64 + %6350 = arith.maxsi %6348, %6349 : i64 + %6351 = arith.index_cast %6350 : i64 to index + %from_elements_2207 = tensor.from_elements %6351, %c4096 : tensor<2xindex> + %6352 = stablehlo.dynamic_broadcast_in_dim %6306, %from_elements_2207, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_2208 = tensor.dim %6352, %c0 : tensor + %6353 = arith.index_cast %dim_2208 : index to i64 + %from_elements_2209 = tensor.from_elements %6353, %c4096_i64 : tensor<2xi64> + %6354 = stablehlo.real_dynamic_slice %6347, %c_22, %from_elements_2209, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_2210 = tensor.from_elements %6353, %c4096_i64, %c1_i64 : tensor<3xi64> + %6355 = stablehlo.dynamic_reshape %6352, %from_elements_2210 : (tensor, tensor<3xi64>) -> tensor + %6356 = stablehlo.dynamic_iota %from_elements_2210, dim = 1 : (tensor<3xi64>) -> tensor + %6357 = stablehlo.concatenate %6355, %6356, dim = 2 : (tensor, tensor) -> tensor + %6358 = "stablehlo.scatter"(%cst_2, %6357, %6354) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %6359 = stablehlo.slice %6293 [1:2, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %6360 = stablehlo.reshape %6359 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %6361 = stablehlo.custom_call @byteir.non_zero(%6360) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_2211 = tensor.dim %6361, %c0 : tensor + %6362 = arith.index_cast %dim_2211 : index to i64 + %from_elements_2212 = tensor.from_elements %6362, %c1_i64 : tensor<2xi64> + %6363 = stablehlo.real_dynamic_slice %6361, %c_22, %from_elements_2212, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2213 = tensor.dim %6363, %c0 : tensor + %6364 = arith.index_cast %dim_2213 : index to i64 + %from_elements_2214 = tensor.from_elements %6364 : tensor<1xi64> + %6365 = stablehlo.dynamic_reshape %6363, %from_elements_2214 : (tensor, tensor<1xi64>) -> tensor + %from_elements_2215 = tensor.from_elements %6362, %c2_i64 : tensor<2xi64> + %6366 = stablehlo.real_dynamic_slice %6361, %c_24, %from_elements_2215, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2216 = tensor.dim %6366, %c0 : tensor + %6367 = arith.index_cast %dim_2216 : index to i64 + %from_elements_2217 = tensor.from_elements %6367 : tensor<1xi64> + 
%6368 = stablehlo.dynamic_reshape %6366, %from_elements_2217 : (tensor, tensor<1xi64>) -> tensor + %dim_2218 = tensor.dim %6368, %c0 : tensor + %6369 = arith.index_cast %dim_2218 : index to i64 + %from_elements_2219 = tensor.from_elements %6369, %c1_i64 : tensor<2xi64> + %6370 = stablehlo.dynamic_reshape %6368, %from_elements_2219 : (tensor, tensor<2xi64>) -> tensor + %dim_2220 = tensor.dim %6370, %c0 : tensor + %6371 = arith.index_cast %dim_2220 : index to i64 + %from_elements_2221 = tensor.from_elements %c1_i64, %6371, %c4096_i64 : tensor<3xi64> + %6372 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_2221, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2222 = tensor.dim %6372, %c1 : tensor<1x?x4096xi64> + %6373 = arith.index_cast %dim_2222 : index to i64 + %from_elements_2223 = tensor.from_elements %c1_i64, %6373, %c4096_i64, %c1_i64 : tensor<4xi64> + %6374 = stablehlo.dynamic_reshape %6372, %from_elements_2223 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %6375 = stablehlo.dynamic_broadcast_in_dim %6370, %from_elements_2221, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2224 = tensor.dim %6375, %c1 : tensor<1x?x4096xi64> + %6376 = arith.index_cast %dim_2224 : index to i64 + %from_elements_2225 = tensor.from_elements %c1_i64, %6376, %c4096_i64, %c1_i64 : tensor<4xi64> + %6377 = stablehlo.dynamic_reshape %6375, %from_elements_2225 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %6378 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_2221, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2226 = tensor.dim %6378, %c1 : tensor<1x?x4096xi64> + %6379 = arith.index_cast %dim_2226 : index to i64 + %from_elements_2227 = tensor.from_elements %c1_i64, %6379, %c4096_i64, %c1_i64 : tensor<4xi64> + %6380 = stablehlo.dynamic_reshape %6378, %from_elements_2227 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %6381 = stablehlo.concatenate %6374, %6377, %6380, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %6382 = "stablehlo.gather"(%6304, %6381) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %6383 = shape.shape_of %6382 : tensor<1x?x4096xf32> -> tensor<3xindex> + %6384 = shape.num_elements %6383 : tensor<3xindex> -> index + %6385 = stablehlo.compute_reshape_shape %6384, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %6386 = stablehlo.dynamic_reshape %6382, %6385 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %6387 = stablehlo.dot %6386, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %6388 = stablehlo.logistic %6387 : tensor + %6389 = shape.shape_of %6388 : tensor -> tensor<2xindex> + %6390 = shape.shape_of %6387 : tensor -> tensor<2xindex> + %6391 = shape.cstr_broadcastable %6389, %6390 : tensor<2xindex>, tensor<2xindex> + %6392 = shape.assuming %6391 -> (tensor) { + %19688 = shape.broadcast %6389, %6390 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %6388, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %6387, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %6393 = shape.shape_of %6392 : tensor -> tensor<2xindex> + %6394 = 
shape.cstr_broadcastable %6393, %6390 : tensor<2xindex>, tensor<2xindex> + %6395 = shape.assuming %6394 -> (tensor) { + %19688 = shape.broadcast %6393, %6390 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %6392, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %6387, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %6396 = stablehlo.dot %6395, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_2228 = tensor.dim %6368, %c0 : tensor + %6397 = arith.index_cast %dim_2228 : index to i64 + %from_elements_2229 = tensor.from_elements %6397, %c1_i64 : tensor<2xi64> + %6398 = stablehlo.dynamic_reshape %6368, %from_elements_2229 : (tensor, tensor<2xi64>) -> tensor + %dim_2230 = tensor.dim %6365, %c0 : tensor + %6399 = arith.index_cast %dim_2230 : index to i64 + %from_elements_2231 = tensor.from_elements %6399, %c1_i64 : tensor<2xi64> + %6400 = stablehlo.dynamic_reshape %6365, %from_elements_2231 : (tensor, tensor<2xi64>) -> tensor + %6401 = stablehlo.concatenate %6398, %6400, dim = 1 : (tensor, tensor) -> tensor + %6402 = "stablehlo.gather"(%6333, %6401) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %6403 = shape.shape_of %6396 : tensor -> tensor<2xindex> + %6404 = shape.shape_of %6402 : tensor -> tensor<2xindex> + %6405 = shape.cstr_broadcastable %6403, %6404 : tensor<2xindex>, tensor<2xindex> + %6406 = shape.assuming %6405 -> (tensor) { + %19688 = shape.broadcast %6403, %6404 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %6396, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %6402, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %6407 = shape.shape_of %6406 : tensor -> tensor<2xindex> + %6408 = stablehlo.dynamic_broadcast_in_dim %6406, %6407, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %6409 = stablehlo.dynamic_broadcast_in_dim %213, %6407, dims = [] : (tensor, tensor<2xindex>) -> tensor + %6410 = stablehlo.multiply %6408, %6409 : tensor + %dim_2232 = tensor.dim %6370, %c0 : tensor + %6411 = arith.index_cast %dim_2232 : index to i64 + %dim_2233 = tensor.dim %6406, %c0 : tensor + %6412 = arith.index_cast %dim_2233 : index to i64 + %6413 = arith.maxsi %6411, %6412 : i64 + %6414 = arith.index_cast %6413 : i64 to index + %from_elements_2234 = tensor.from_elements %6414, %c4096 : tensor<2xindex> + %6415 = stablehlo.dynamic_broadcast_in_dim %6370, %from_elements_2234, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_2235 = tensor.dim %6415, %c0 : tensor + %6416 = arith.index_cast %dim_2235 : index to i64 + %from_elements_2236 = tensor.from_elements %6416, %c4096_i64 : tensor<2xi64> + %6417 = stablehlo.real_dynamic_slice %6410, %c_22, %from_elements_2236, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_2237 = tensor.from_elements %6416, %c4096_i64, %c1_i64 : tensor<3xi64> + %6418 = stablehlo.dynamic_reshape %6415, %from_elements_2237 : (tensor, tensor<3xi64>) -> tensor + %6419 = stablehlo.dynamic_iota %from_elements_2237, dim = 1 : (tensor<3xi64>) -> tensor + %6420 = stablehlo.concatenate %6418, %6419, dim = 2 : 
(tensor, tensor) -> tensor + %6421 = "stablehlo.scatter"(%6358, %6420, %6417) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %6422 = stablehlo.slice %6293 [2:3, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %6423 = stablehlo.reshape %6422 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %6424 = stablehlo.custom_call @byteir.non_zero(%6423) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_2238 = tensor.dim %6424, %c0 : tensor + %6425 = arith.index_cast %dim_2238 : index to i64 + %from_elements_2239 = tensor.from_elements %6425, %c1_i64 : tensor<2xi64> + %6426 = stablehlo.real_dynamic_slice %6424, %c_22, %from_elements_2239, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2240 = tensor.dim %6426, %c0 : tensor + %6427 = arith.index_cast %dim_2240 : index to i64 + %from_elements_2241 = tensor.from_elements %6427 : tensor<1xi64> + %6428 = stablehlo.dynamic_reshape %6426, %from_elements_2241 : (tensor, tensor<1xi64>) -> tensor + %from_elements_2242 = tensor.from_elements %6425, %c2_i64 : tensor<2xi64> + %6429 = stablehlo.real_dynamic_slice %6424, %c_24, %from_elements_2242, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2243 = tensor.dim %6429, %c0 : tensor + %6430 = arith.index_cast %dim_2243 : index to i64 + %from_elements_2244 = tensor.from_elements %6430 : tensor<1xi64> + %6431 = stablehlo.dynamic_reshape %6429, %from_elements_2244 : (tensor, tensor<1xi64>) -> tensor + %dim_2245 = tensor.dim %6431, %c0 : tensor + %6432 = arith.index_cast %dim_2245 : index to i64 + %from_elements_2246 = tensor.from_elements %6432, %c1_i64 : tensor<2xi64> + %6433 = stablehlo.dynamic_reshape %6431, %from_elements_2246 : (tensor, tensor<2xi64>) -> tensor + %dim_2247 = tensor.dim %6433, %c0 : tensor + %6434 = arith.index_cast %dim_2247 : index to i64 + %from_elements_2248 = tensor.from_elements %c1_i64, %6434, %c4096_i64 : tensor<3xi64> + %6435 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_2248, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2249 = tensor.dim %6435, %c1 : tensor<1x?x4096xi64> + %6436 = arith.index_cast %dim_2249 : index to i64 + %from_elements_2250 = tensor.from_elements %c1_i64, %6436, %c4096_i64, %c1_i64 : tensor<4xi64> + %6437 = stablehlo.dynamic_reshape %6435, %from_elements_2250 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %6438 = stablehlo.dynamic_broadcast_in_dim %6433, %from_elements_2248, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2251 = tensor.dim %6438, %c1 : tensor<1x?x4096xi64> + %6439 = arith.index_cast %dim_2251 : index to i64 + %from_elements_2252 = tensor.from_elements %c1_i64, %6439, %c4096_i64, %c1_i64 : tensor<4xi64> + %6440 = stablehlo.dynamic_reshape %6438, %from_elements_2252 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %6441 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_2248, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2253 = tensor.dim %6441, %c1 : tensor<1x?x4096xi64> + %6442 = arith.index_cast %dim_2253 : index to i64 + %from_elements_2254 = tensor.from_elements %c1_i64, %6442, %c4096_i64, %c1_i64 : tensor<4xi64> + %6443 = stablehlo.dynamic_reshape %6441, %from_elements_2254 : 
(tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %6444 = stablehlo.concatenate %6437, %6440, %6443, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %6445 = "stablehlo.gather"(%6304, %6444) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %6446 = shape.shape_of %6445 : tensor<1x?x4096xf32> -> tensor<3xindex> + %6447 = shape.num_elements %6446 : tensor<3xindex> -> index + %6448 = stablehlo.compute_reshape_shape %6447, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %6449 = stablehlo.dynamic_reshape %6445, %6448 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %6450 = stablehlo.dot %6449, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %6451 = stablehlo.logistic %6450 : tensor + %6452 = shape.shape_of %6451 : tensor -> tensor<2xindex> + %6453 = shape.shape_of %6450 : tensor -> tensor<2xindex> + %6454 = shape.cstr_broadcastable %6452, %6453 : tensor<2xindex>, tensor<2xindex> + %6455 = shape.assuming %6454 -> (tensor) { + %19688 = shape.broadcast %6452, %6453 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %6451, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %6450, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %6456 = shape.shape_of %6455 : tensor -> tensor<2xindex> + %6457 = shape.cstr_broadcastable %6456, %6453 : tensor<2xindex>, tensor<2xindex> + %6458 = shape.assuming %6457 -> (tensor) { + %19688 = shape.broadcast %6456, %6453 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %6455, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %6450, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %6459 = stablehlo.dot %6458, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_2255 = tensor.dim %6431, %c0 : tensor + %6460 = arith.index_cast %dim_2255 : index to i64 + %from_elements_2256 = tensor.from_elements %6460, %c1_i64 : tensor<2xi64> + %6461 = stablehlo.dynamic_reshape %6431, %from_elements_2256 : (tensor, tensor<2xi64>) -> tensor + %dim_2257 = tensor.dim %6428, %c0 : tensor + %6462 = arith.index_cast %dim_2257 : index to i64 + %from_elements_2258 = tensor.from_elements %6462, %c1_i64 : tensor<2xi64> + %6463 = stablehlo.dynamic_reshape %6428, %from_elements_2258 : (tensor, tensor<2xi64>) -> tensor + %6464 = stablehlo.concatenate %6461, %6463, dim = 1 : (tensor, tensor) -> tensor + %6465 = "stablehlo.gather"(%6333, %6464) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %6466 = shape.shape_of %6459 : tensor -> tensor<2xindex> + %6467 = shape.shape_of %6465 : tensor -> tensor<2xindex> + %6468 = shape.cstr_broadcastable %6466, %6467 : tensor<2xindex>, tensor<2xindex> + %6469 = shape.assuming %6468 -> (tensor) { + %19688 = shape.broadcast %6466, %6467 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %6459, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %6465, %19688, 
dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %6470 = shape.shape_of %6469 : tensor -> tensor<2xindex> + %6471 = stablehlo.dynamic_broadcast_in_dim %6469, %6470, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %6472 = stablehlo.dynamic_broadcast_in_dim %213, %6470, dims = [] : (tensor, tensor<2xindex>) -> tensor + %6473 = stablehlo.multiply %6471, %6472 : tensor + %dim_2259 = tensor.dim %6433, %c0 : tensor + %6474 = arith.index_cast %dim_2259 : index to i64 + %dim_2260 = tensor.dim %6469, %c0 : tensor + %6475 = arith.index_cast %dim_2260 : index to i64 + %6476 = arith.maxsi %6474, %6475 : i64 + %6477 = arith.index_cast %6476 : i64 to index + %from_elements_2261 = tensor.from_elements %6477, %c4096 : tensor<2xindex> + %6478 = stablehlo.dynamic_broadcast_in_dim %6433, %from_elements_2261, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_2262 = tensor.dim %6478, %c0 : tensor + %6479 = arith.index_cast %dim_2262 : index to i64 + %from_elements_2263 = tensor.from_elements %6479, %c4096_i64 : tensor<2xi64> + %6480 = stablehlo.real_dynamic_slice %6473, %c_22, %from_elements_2263, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_2264 = tensor.from_elements %6479, %c4096_i64, %c1_i64 : tensor<3xi64> + %6481 = stablehlo.dynamic_reshape %6478, %from_elements_2264 : (tensor, tensor<3xi64>) -> tensor + %6482 = stablehlo.dynamic_iota %from_elements_2264, dim = 1 : (tensor<3xi64>) -> tensor + %6483 = stablehlo.concatenate %6481, %6482, dim = 2 : (tensor, tensor) -> tensor + %6484 = "stablehlo.scatter"(%6421, %6483, %6480) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %6485 = stablehlo.slice %6293 [3:4, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %6486 = stablehlo.reshape %6485 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %6487 = stablehlo.custom_call @byteir.non_zero(%6486) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_2265 = tensor.dim %6487, %c0 : tensor + %6488 = arith.index_cast %dim_2265 : index to i64 + %from_elements_2266 = tensor.from_elements %6488, %c1_i64 : tensor<2xi64> + %6489 = stablehlo.real_dynamic_slice %6487, %c_22, %from_elements_2266, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2267 = tensor.dim %6489, %c0 : tensor + %6490 = arith.index_cast %dim_2267 : index to i64 + %from_elements_2268 = tensor.from_elements %6490 : tensor<1xi64> + %6491 = stablehlo.dynamic_reshape %6489, %from_elements_2268 : (tensor, tensor<1xi64>) -> tensor + %from_elements_2269 = tensor.from_elements %6488, %c2_i64 : tensor<2xi64> + %6492 = stablehlo.real_dynamic_slice %6487, %c_24, %from_elements_2269, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2270 = tensor.dim %6492, %c0 : tensor + %6493 = arith.index_cast %dim_2270 : index to i64 + %from_elements_2271 = tensor.from_elements %6493 : tensor<1xi64> + %6494 = stablehlo.dynamic_reshape %6492, %from_elements_2271 : (tensor, tensor<1xi64>) -> tensor + %dim_2272 = tensor.dim %6494, %c0 : tensor + %6495 = arith.index_cast %dim_2272 : index to i64 + %from_elements_2273 = tensor.from_elements %6495, %c1_i64 : tensor<2xi64> + %6496 = stablehlo.dynamic_reshape %6494, 
%from_elements_2273 : (tensor, tensor<2xi64>) -> tensor + %dim_2274 = tensor.dim %6496, %c0 : tensor + %6497 = arith.index_cast %dim_2274 : index to i64 + %from_elements_2275 = tensor.from_elements %c1_i64, %6497, %c4096_i64 : tensor<3xi64> + %6498 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_2275, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2276 = tensor.dim %6498, %c1 : tensor<1x?x4096xi64> + %6499 = arith.index_cast %dim_2276 : index to i64 + %from_elements_2277 = tensor.from_elements %c1_i64, %6499, %c4096_i64, %c1_i64 : tensor<4xi64> + %6500 = stablehlo.dynamic_reshape %6498, %from_elements_2277 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %6501 = stablehlo.dynamic_broadcast_in_dim %6496, %from_elements_2275, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2278 = tensor.dim %6501, %c1 : tensor<1x?x4096xi64> + %6502 = arith.index_cast %dim_2278 : index to i64 + %from_elements_2279 = tensor.from_elements %c1_i64, %6502, %c4096_i64, %c1_i64 : tensor<4xi64> + %6503 = stablehlo.dynamic_reshape %6501, %from_elements_2279 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %6504 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_2275, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2280 = tensor.dim %6504, %c1 : tensor<1x?x4096xi64> + %6505 = arith.index_cast %dim_2280 : index to i64 + %from_elements_2281 = tensor.from_elements %c1_i64, %6505, %c4096_i64, %c1_i64 : tensor<4xi64> + %6506 = stablehlo.dynamic_reshape %6504, %from_elements_2281 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %6507 = stablehlo.concatenate %6500, %6503, %6506, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %6508 = "stablehlo.gather"(%6304, %6507) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %6509 = shape.shape_of %6508 : tensor<1x?x4096xf32> -> tensor<3xindex> + %6510 = shape.num_elements %6509 : tensor<3xindex> -> index + %6511 = stablehlo.compute_reshape_shape %6510, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %6512 = stablehlo.dynamic_reshape %6508, %6511 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %6513 = stablehlo.dot %6512, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %6514 = stablehlo.logistic %6513 : tensor + %6515 = shape.shape_of %6514 : tensor -> tensor<2xindex> + %6516 = shape.shape_of %6513 : tensor -> tensor<2xindex> + %6517 = shape.cstr_broadcastable %6515, %6516 : tensor<2xindex>, tensor<2xindex> + %6518 = shape.assuming %6517 -> (tensor) { + %19688 = shape.broadcast %6515, %6516 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %6514, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %6513, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %6519 = shape.shape_of %6518 : tensor -> tensor<2xindex> + %6520 = shape.cstr_broadcastable %6519, %6516 : tensor<2xindex>, tensor<2xindex> + %6521 = shape.assuming %6520 -> (tensor) { + %19688 = shape.broadcast %6519, %6516 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %6518, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) 
-> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %6513, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %6522 = stablehlo.dot %6521, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_2282 = tensor.dim %6494, %c0 : tensor + %6523 = arith.index_cast %dim_2282 : index to i64 + %from_elements_2283 = tensor.from_elements %6523, %c1_i64 : tensor<2xi64> + %6524 = stablehlo.dynamic_reshape %6494, %from_elements_2283 : (tensor, tensor<2xi64>) -> tensor + %dim_2284 = tensor.dim %6491, %c0 : tensor + %6525 = arith.index_cast %dim_2284 : index to i64 + %from_elements_2285 = tensor.from_elements %6525, %c1_i64 : tensor<2xi64> + %6526 = stablehlo.dynamic_reshape %6491, %from_elements_2285 : (tensor, tensor<2xi64>) -> tensor + %6527 = stablehlo.concatenate %6524, %6526, dim = 1 : (tensor, tensor) -> tensor + %6528 = "stablehlo.gather"(%6333, %6527) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %6529 = shape.shape_of %6522 : tensor -> tensor<2xindex> + %6530 = shape.shape_of %6528 : tensor -> tensor<2xindex> + %6531 = shape.cstr_broadcastable %6529, %6530 : tensor<2xindex>, tensor<2xindex> + %6532 = shape.assuming %6531 -> (tensor) { + %19688 = shape.broadcast %6529, %6530 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %6522, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %6528, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %6533 = shape.shape_of %6532 : tensor -> tensor<2xindex> + %6534 = stablehlo.dynamic_broadcast_in_dim %6532, %6533, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %6535 = stablehlo.dynamic_broadcast_in_dim %213, %6533, dims = [] : (tensor, tensor<2xindex>) -> tensor + %6536 = stablehlo.multiply %6534, %6535 : tensor + %dim_2286 = tensor.dim %6496, %c0 : tensor + %6537 = arith.index_cast %dim_2286 : index to i64 + %dim_2287 = tensor.dim %6532, %c0 : tensor + %6538 = arith.index_cast %dim_2287 : index to i64 + %6539 = arith.maxsi %6537, %6538 : i64 + %6540 = arith.index_cast %6539 : i64 to index + %from_elements_2288 = tensor.from_elements %6540, %c4096 : tensor<2xindex> + %6541 = stablehlo.dynamic_broadcast_in_dim %6496, %from_elements_2288, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_2289 = tensor.dim %6541, %c0 : tensor + %6542 = arith.index_cast %dim_2289 : index to i64 + %from_elements_2290 = tensor.from_elements %6542, %c4096_i64 : tensor<2xi64> + %6543 = stablehlo.real_dynamic_slice %6536, %c_22, %from_elements_2290, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_2291 = tensor.from_elements %6542, %c4096_i64, %c1_i64 : tensor<3xi64> + %6544 = stablehlo.dynamic_reshape %6541, %from_elements_2291 : (tensor, tensor<3xi64>) -> tensor + %6545 = stablehlo.dynamic_iota %from_elements_2291, dim = 1 : (tensor<3xi64>) -> tensor + %6546 = stablehlo.concatenate %6544, %6545, dim = 2 : (tensor, tensor) -> tensor + %6547 = "stablehlo.scatter"(%6484, %6546, %6543) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : 
(tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %6548 = stablehlo.slice %6293 [4:5, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %6549 = stablehlo.reshape %6548 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %6550 = stablehlo.custom_call @byteir.non_zero(%6549) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_2292 = tensor.dim %6550, %c0 : tensor + %6551 = arith.index_cast %dim_2292 : index to i64 + %from_elements_2293 = tensor.from_elements %6551, %c1_i64 : tensor<2xi64> + %6552 = stablehlo.real_dynamic_slice %6550, %c_22, %from_elements_2293, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2294 = tensor.dim %6552, %c0 : tensor + %6553 = arith.index_cast %dim_2294 : index to i64 + %from_elements_2295 = tensor.from_elements %6553 : tensor<1xi64> + %6554 = stablehlo.dynamic_reshape %6552, %from_elements_2295 : (tensor, tensor<1xi64>) -> tensor + %from_elements_2296 = tensor.from_elements %6551, %c2_i64 : tensor<2xi64> + %6555 = stablehlo.real_dynamic_slice %6550, %c_24, %from_elements_2296, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2297 = tensor.dim %6555, %c0 : tensor + %6556 = arith.index_cast %dim_2297 : index to i64 + %from_elements_2298 = tensor.from_elements %6556 : tensor<1xi64> + %6557 = stablehlo.dynamic_reshape %6555, %from_elements_2298 : (tensor, tensor<1xi64>) -> tensor + %dim_2299 = tensor.dim %6557, %c0 : tensor + %6558 = arith.index_cast %dim_2299 : index to i64 + %from_elements_2300 = tensor.from_elements %6558, %c1_i64 : tensor<2xi64> + %6559 = stablehlo.dynamic_reshape %6557, %from_elements_2300 : (tensor, tensor<2xi64>) -> tensor + %dim_2301 = tensor.dim %6559, %c0 : tensor + %6560 = arith.index_cast %dim_2301 : index to i64 + %from_elements_2302 = tensor.from_elements %c1_i64, %6560, %c4096_i64 : tensor<3xi64> + %6561 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_2302, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2303 = tensor.dim %6561, %c1 : tensor<1x?x4096xi64> + %6562 = arith.index_cast %dim_2303 : index to i64 + %from_elements_2304 = tensor.from_elements %c1_i64, %6562, %c4096_i64, %c1_i64 : tensor<4xi64> + %6563 = stablehlo.dynamic_reshape %6561, %from_elements_2304 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %6564 = stablehlo.dynamic_broadcast_in_dim %6559, %from_elements_2302, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2305 = tensor.dim %6564, %c1 : tensor<1x?x4096xi64> + %6565 = arith.index_cast %dim_2305 : index to i64 + %from_elements_2306 = tensor.from_elements %c1_i64, %6565, %c4096_i64, %c1_i64 : tensor<4xi64> + %6566 = stablehlo.dynamic_reshape %6564, %from_elements_2306 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %6567 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_2302, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2307 = tensor.dim %6567, %c1 : tensor<1x?x4096xi64> + %6568 = arith.index_cast %dim_2307 : index to i64 + %from_elements_2308 = tensor.from_elements %c1_i64, %6568, %c4096_i64, %c1_i64 : tensor<4xi64> + %6569 = stablehlo.dynamic_reshape %6567, %from_elements_2308 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %6570 = stablehlo.concatenate %6563, %6566, %6569, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %6571 = "stablehlo.gather"(%6304, %6570) <{dimension_numbers = #stablehlo.gather, 
indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %6572 = shape.shape_of %6571 : tensor<1x?x4096xf32> -> tensor<3xindex> + %6573 = shape.num_elements %6572 : tensor<3xindex> -> index + %6574 = stablehlo.compute_reshape_shape %6573, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %6575 = stablehlo.dynamic_reshape %6571, %6574 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %6576 = stablehlo.dot %6575, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %6577 = stablehlo.logistic %6576 : tensor + %6578 = shape.shape_of %6577 : tensor -> tensor<2xindex> + %6579 = shape.shape_of %6576 : tensor -> tensor<2xindex> + %6580 = shape.cstr_broadcastable %6578, %6579 : tensor<2xindex>, tensor<2xindex> + %6581 = shape.assuming %6580 -> (tensor) { + %19688 = shape.broadcast %6578, %6579 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %6577, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %6576, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %6582 = shape.shape_of %6581 : tensor -> tensor<2xindex> + %6583 = shape.cstr_broadcastable %6582, %6579 : tensor<2xindex>, tensor<2xindex> + %6584 = shape.assuming %6583 -> (tensor) { + %19688 = shape.broadcast %6582, %6579 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %6581, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %6576, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %6585 = stablehlo.dot %6584, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_2309 = tensor.dim %6557, %c0 : tensor + %6586 = arith.index_cast %dim_2309 : index to i64 + %from_elements_2310 = tensor.from_elements %6586, %c1_i64 : tensor<2xi64> + %6587 = stablehlo.dynamic_reshape %6557, %from_elements_2310 : (tensor, tensor<2xi64>) -> tensor + %dim_2311 = tensor.dim %6554, %c0 : tensor + %6588 = arith.index_cast %dim_2311 : index to i64 + %from_elements_2312 = tensor.from_elements %6588, %c1_i64 : tensor<2xi64> + %6589 = stablehlo.dynamic_reshape %6554, %from_elements_2312 : (tensor, tensor<2xi64>) -> tensor + %6590 = stablehlo.concatenate %6587, %6589, dim = 1 : (tensor, tensor) -> tensor + %6591 = "stablehlo.gather"(%6333, %6590) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %6592 = shape.shape_of %6585 : tensor -> tensor<2xindex> + %6593 = shape.shape_of %6591 : tensor -> tensor<2xindex> + %6594 = shape.cstr_broadcastable %6592, %6593 : tensor<2xindex>, tensor<2xindex> + %6595 = shape.assuming %6594 -> (tensor) { + %19688 = shape.broadcast %6592, %6593 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %6585, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %6591, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %6596 = shape.shape_of %6595 : tensor -> tensor<2xindex> + %6597 = stablehlo.dynamic_broadcast_in_dim %6595, %6596, dims = [0, 1] : (tensor, tensor<2xindex>) -> 
tensor + %6598 = stablehlo.dynamic_broadcast_in_dim %213, %6596, dims = [] : (tensor, tensor<2xindex>) -> tensor + %6599 = stablehlo.multiply %6597, %6598 : tensor + %dim_2313 = tensor.dim %6559, %c0 : tensor + %6600 = arith.index_cast %dim_2313 : index to i64 + %dim_2314 = tensor.dim %6595, %c0 : tensor + %6601 = arith.index_cast %dim_2314 : index to i64 + %6602 = arith.maxsi %6600, %6601 : i64 + %6603 = arith.index_cast %6602 : i64 to index + %from_elements_2315 = tensor.from_elements %6603, %c4096 : tensor<2xindex> + %6604 = stablehlo.dynamic_broadcast_in_dim %6559, %from_elements_2315, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_2316 = tensor.dim %6604, %c0 : tensor + %6605 = arith.index_cast %dim_2316 : index to i64 + %from_elements_2317 = tensor.from_elements %6605, %c4096_i64 : tensor<2xi64> + %6606 = stablehlo.real_dynamic_slice %6599, %c_22, %from_elements_2317, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_2318 = tensor.from_elements %6605, %c4096_i64, %c1_i64 : tensor<3xi64> + %6607 = stablehlo.dynamic_reshape %6604, %from_elements_2318 : (tensor, tensor<3xi64>) -> tensor + %6608 = stablehlo.dynamic_iota %from_elements_2318, dim = 1 : (tensor<3xi64>) -> tensor + %6609 = stablehlo.concatenate %6607, %6608, dim = 2 : (tensor, tensor) -> tensor + %6610 = "stablehlo.scatter"(%6547, %6609, %6606) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %6611 = stablehlo.slice %6293 [5:6, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %6612 = stablehlo.reshape %6611 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %6613 = stablehlo.custom_call @byteir.non_zero(%6612) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_2319 = tensor.dim %6613, %c0 : tensor + %6614 = arith.index_cast %dim_2319 : index to i64 + %from_elements_2320 = tensor.from_elements %6614, %c1_i64 : tensor<2xi64> + %6615 = stablehlo.real_dynamic_slice %6613, %c_22, %from_elements_2320, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2321 = tensor.dim %6615, %c0 : tensor + %6616 = arith.index_cast %dim_2321 : index to i64 + %from_elements_2322 = tensor.from_elements %6616 : tensor<1xi64> + %6617 = stablehlo.dynamic_reshape %6615, %from_elements_2322 : (tensor, tensor<1xi64>) -> tensor + %from_elements_2323 = tensor.from_elements %6614, %c2_i64 : tensor<2xi64> + %6618 = stablehlo.real_dynamic_slice %6613, %c_24, %from_elements_2323, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2324 = tensor.dim %6618, %c0 : tensor + %6619 = arith.index_cast %dim_2324 : index to i64 + %from_elements_2325 = tensor.from_elements %6619 : tensor<1xi64> + %6620 = stablehlo.dynamic_reshape %6618, %from_elements_2325 : (tensor, tensor<1xi64>) -> tensor + %dim_2326 = tensor.dim %6620, %c0 : tensor + %6621 = arith.index_cast %dim_2326 : index to i64 + %from_elements_2327 = tensor.from_elements %6621, %c1_i64 : tensor<2xi64> + %6622 = stablehlo.dynamic_reshape %6620, %from_elements_2327 : (tensor, tensor<2xi64>) -> tensor + %dim_2328 = tensor.dim %6622, %c0 : tensor + %6623 = arith.index_cast %dim_2328 : index to i64 + %from_elements_2329 = tensor.from_elements %c1_i64, %6623, %c4096_i64 : tensor<3xi64> + %6624 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_2329, dims 
= [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2330 = tensor.dim %6624, %c1 : tensor<1x?x4096xi64> + %6625 = arith.index_cast %dim_2330 : index to i64 + %from_elements_2331 = tensor.from_elements %c1_i64, %6625, %c4096_i64, %c1_i64 : tensor<4xi64> + %6626 = stablehlo.dynamic_reshape %6624, %from_elements_2331 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %6627 = stablehlo.dynamic_broadcast_in_dim %6622, %from_elements_2329, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2332 = tensor.dim %6627, %c1 : tensor<1x?x4096xi64> + %6628 = arith.index_cast %dim_2332 : index to i64 + %from_elements_2333 = tensor.from_elements %c1_i64, %6628, %c4096_i64, %c1_i64 : tensor<4xi64> + %6629 = stablehlo.dynamic_reshape %6627, %from_elements_2333 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %6630 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_2329, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2334 = tensor.dim %6630, %c1 : tensor<1x?x4096xi64> + %6631 = arith.index_cast %dim_2334 : index to i64 + %from_elements_2335 = tensor.from_elements %c1_i64, %6631, %c4096_i64, %c1_i64 : tensor<4xi64> + %6632 = stablehlo.dynamic_reshape %6630, %from_elements_2335 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %6633 = stablehlo.concatenate %6626, %6629, %6632, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %6634 = "stablehlo.gather"(%6304, %6633) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %6635 = shape.shape_of %6634 : tensor<1x?x4096xf32> -> tensor<3xindex> + %6636 = shape.num_elements %6635 : tensor<3xindex> -> index + %6637 = stablehlo.compute_reshape_shape %6636, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %6638 = stablehlo.dynamic_reshape %6634, %6637 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %6639 = stablehlo.dot %6638, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %6640 = stablehlo.logistic %6639 : tensor + %6641 = shape.shape_of %6640 : tensor -> tensor<2xindex> + %6642 = shape.shape_of %6639 : tensor -> tensor<2xindex> + %6643 = shape.cstr_broadcastable %6641, %6642 : tensor<2xindex>, tensor<2xindex> + %6644 = shape.assuming %6643 -> (tensor) { + %19688 = shape.broadcast %6641, %6642 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %6640, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %6639, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %6645 = shape.shape_of %6644 : tensor -> tensor<2xindex> + %6646 = shape.cstr_broadcastable %6645, %6642 : tensor<2xindex>, tensor<2xindex> + %6647 = shape.assuming %6646 -> (tensor) { + %19688 = shape.broadcast %6645, %6642 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %6644, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %6639, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %6648 = stablehlo.dot %6647, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_2336 = 
tensor.dim %6620, %c0 : tensor + %6649 = arith.index_cast %dim_2336 : index to i64 + %from_elements_2337 = tensor.from_elements %6649, %c1_i64 : tensor<2xi64> + %6650 = stablehlo.dynamic_reshape %6620, %from_elements_2337 : (tensor, tensor<2xi64>) -> tensor + %dim_2338 = tensor.dim %6617, %c0 : tensor + %6651 = arith.index_cast %dim_2338 : index to i64 + %from_elements_2339 = tensor.from_elements %6651, %c1_i64 : tensor<2xi64> + %6652 = stablehlo.dynamic_reshape %6617, %from_elements_2339 : (tensor, tensor<2xi64>) -> tensor + %6653 = stablehlo.concatenate %6650, %6652, dim = 1 : (tensor, tensor) -> tensor + %6654 = "stablehlo.gather"(%6333, %6653) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %6655 = shape.shape_of %6648 : tensor -> tensor<2xindex> + %6656 = shape.shape_of %6654 : tensor -> tensor<2xindex> + %6657 = shape.cstr_broadcastable %6655, %6656 : tensor<2xindex>, tensor<2xindex> + %6658 = shape.assuming %6657 -> (tensor) { + %19688 = shape.broadcast %6655, %6656 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %6648, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %6654, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %6659 = shape.shape_of %6658 : tensor -> tensor<2xindex> + %6660 = stablehlo.dynamic_broadcast_in_dim %6658, %6659, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %6661 = stablehlo.dynamic_broadcast_in_dim %213, %6659, dims = [] : (tensor, tensor<2xindex>) -> tensor + %6662 = stablehlo.multiply %6660, %6661 : tensor + %dim_2340 = tensor.dim %6622, %c0 : tensor + %6663 = arith.index_cast %dim_2340 : index to i64 + %dim_2341 = tensor.dim %6658, %c0 : tensor + %6664 = arith.index_cast %dim_2341 : index to i64 + %6665 = arith.maxsi %6663, %6664 : i64 + %6666 = arith.index_cast %6665 : i64 to index + %from_elements_2342 = tensor.from_elements %6666, %c4096 : tensor<2xindex> + %6667 = stablehlo.dynamic_broadcast_in_dim %6622, %from_elements_2342, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_2343 = tensor.dim %6667, %c0 : tensor + %6668 = arith.index_cast %dim_2343 : index to i64 + %from_elements_2344 = tensor.from_elements %6668, %c4096_i64 : tensor<2xi64> + %6669 = stablehlo.real_dynamic_slice %6662, %c_22, %from_elements_2344, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_2345 = tensor.from_elements %6668, %c4096_i64, %c1_i64 : tensor<3xi64> + %6670 = stablehlo.dynamic_reshape %6667, %from_elements_2345 : (tensor, tensor<3xi64>) -> tensor + %6671 = stablehlo.dynamic_iota %from_elements_2345, dim = 1 : (tensor<3xi64>) -> tensor + %6672 = stablehlo.concatenate %6670, %6671, dim = 2 : (tensor, tensor) -> tensor + %6673 = "stablehlo.scatter"(%6610, %6672, %6669) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %6674 = stablehlo.slice %6293 [6:7, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %6675 = stablehlo.reshape %6674 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %6676 = stablehlo.custom_call @byteir.non_zero(%6675) {byteir_attrs = {}} : 
(tensor<2x3xi64>) -> tensor + %dim_2346 = tensor.dim %6676, %c0 : tensor + %6677 = arith.index_cast %dim_2346 : index to i64 + %from_elements_2347 = tensor.from_elements %6677, %c1_i64 : tensor<2xi64> + %6678 = stablehlo.real_dynamic_slice %6676, %c_22, %from_elements_2347, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2348 = tensor.dim %6678, %c0 : tensor + %6679 = arith.index_cast %dim_2348 : index to i64 + %from_elements_2349 = tensor.from_elements %6679 : tensor<1xi64> + %6680 = stablehlo.dynamic_reshape %6678, %from_elements_2349 : (tensor, tensor<1xi64>) -> tensor + %from_elements_2350 = tensor.from_elements %6677, %c2_i64 : tensor<2xi64> + %6681 = stablehlo.real_dynamic_slice %6676, %c_24, %from_elements_2350, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2351 = tensor.dim %6681, %c0 : tensor + %6682 = arith.index_cast %dim_2351 : index to i64 + %from_elements_2352 = tensor.from_elements %6682 : tensor<1xi64> + %6683 = stablehlo.dynamic_reshape %6681, %from_elements_2352 : (tensor, tensor<1xi64>) -> tensor + %dim_2353 = tensor.dim %6683, %c0 : tensor + %6684 = arith.index_cast %dim_2353 : index to i64 + %from_elements_2354 = tensor.from_elements %6684, %c1_i64 : tensor<2xi64> + %6685 = stablehlo.dynamic_reshape %6683, %from_elements_2354 : (tensor, tensor<2xi64>) -> tensor + %dim_2355 = tensor.dim %6685, %c0 : tensor + %6686 = arith.index_cast %dim_2355 : index to i64 + %from_elements_2356 = tensor.from_elements %c1_i64, %6686, %c4096_i64 : tensor<3xi64> + %6687 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_2356, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2357 = tensor.dim %6687, %c1 : tensor<1x?x4096xi64> + %6688 = arith.index_cast %dim_2357 : index to i64 + %from_elements_2358 = tensor.from_elements %c1_i64, %6688, %c4096_i64, %c1_i64 : tensor<4xi64> + %6689 = stablehlo.dynamic_reshape %6687, %from_elements_2358 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %6690 = stablehlo.dynamic_broadcast_in_dim %6685, %from_elements_2356, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2359 = tensor.dim %6690, %c1 : tensor<1x?x4096xi64> + %6691 = arith.index_cast %dim_2359 : index to i64 + %from_elements_2360 = tensor.from_elements %c1_i64, %6691, %c4096_i64, %c1_i64 : tensor<4xi64> + %6692 = stablehlo.dynamic_reshape %6690, %from_elements_2360 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %6693 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_2356, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2361 = tensor.dim %6693, %c1 : tensor<1x?x4096xi64> + %6694 = arith.index_cast %dim_2361 : index to i64 + %from_elements_2362 = tensor.from_elements %c1_i64, %6694, %c4096_i64, %c1_i64 : tensor<4xi64> + %6695 = stablehlo.dynamic_reshape %6693, %from_elements_2362 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %6696 = stablehlo.concatenate %6689, %6692, %6695, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %6697 = "stablehlo.gather"(%6304, %6696) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %6698 = shape.shape_of %6697 : tensor<1x?x4096xf32> -> tensor<3xindex> + %6699 = shape.num_elements %6698 : tensor<3xindex> -> index + %6700 = stablehlo.compute_reshape_shape %6699, %c_23 : 
(index, tensor<2xi64>) -> tensor<2xi64> + %6701 = stablehlo.dynamic_reshape %6697, %6700 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %6702 = stablehlo.dot %6701, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %6703 = stablehlo.logistic %6702 : tensor + %6704 = shape.shape_of %6703 : tensor -> tensor<2xindex> + %6705 = shape.shape_of %6702 : tensor -> tensor<2xindex> + %6706 = shape.cstr_broadcastable %6704, %6705 : tensor<2xindex>, tensor<2xindex> + %6707 = shape.assuming %6706 -> (tensor) { + %19688 = shape.broadcast %6704, %6705 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %6703, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %6702, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %6708 = shape.shape_of %6707 : tensor -> tensor<2xindex> + %6709 = shape.cstr_broadcastable %6708, %6705 : tensor<2xindex>, tensor<2xindex> + %6710 = shape.assuming %6709 -> (tensor) { + %19688 = shape.broadcast %6708, %6705 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %6707, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %6702, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %6711 = stablehlo.dot %6710, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_2363 = tensor.dim %6683, %c0 : tensor + %6712 = arith.index_cast %dim_2363 : index to i64 + %from_elements_2364 = tensor.from_elements %6712, %c1_i64 : tensor<2xi64> + %6713 = stablehlo.dynamic_reshape %6683, %from_elements_2364 : (tensor, tensor<2xi64>) -> tensor + %dim_2365 = tensor.dim %6680, %c0 : tensor + %6714 = arith.index_cast %dim_2365 : index to i64 + %from_elements_2366 = tensor.from_elements %6714, %c1_i64 : tensor<2xi64> + %6715 = stablehlo.dynamic_reshape %6680, %from_elements_2366 : (tensor, tensor<2xi64>) -> tensor + %6716 = stablehlo.concatenate %6713, %6715, dim = 1 : (tensor, tensor) -> tensor + %6717 = "stablehlo.gather"(%6333, %6716) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %6718 = shape.shape_of %6711 : tensor -> tensor<2xindex> + %6719 = shape.shape_of %6717 : tensor -> tensor<2xindex> + %6720 = shape.cstr_broadcastable %6718, %6719 : tensor<2xindex>, tensor<2xindex> + %6721 = shape.assuming %6720 -> (tensor) { + %19688 = shape.broadcast %6718, %6719 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %6711, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %6717, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %6722 = shape.shape_of %6721 : tensor -> tensor<2xindex> + %6723 = stablehlo.dynamic_broadcast_in_dim %6721, %6722, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %6724 = stablehlo.dynamic_broadcast_in_dim %213, %6722, dims = [] : (tensor, tensor<2xindex>) -> tensor + %6725 = stablehlo.multiply %6723, %6724 : tensor + %dim_2367 = tensor.dim %6685, %c0 : tensor + %6726 = arith.index_cast %dim_2367 : index to i64 + %dim_2368 = tensor.dim %6721, %c0 : tensor + %6727 = 
arith.index_cast %dim_2368 : index to i64 + %6728 = arith.maxsi %6726, %6727 : i64 + %6729 = arith.index_cast %6728 : i64 to index + %from_elements_2369 = tensor.from_elements %6729, %c4096 : tensor<2xindex> + %6730 = stablehlo.dynamic_broadcast_in_dim %6685, %from_elements_2369, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_2370 = tensor.dim %6730, %c0 : tensor + %6731 = arith.index_cast %dim_2370 : index to i64 + %from_elements_2371 = tensor.from_elements %6731, %c4096_i64 : tensor<2xi64> + %6732 = stablehlo.real_dynamic_slice %6725, %c_22, %from_elements_2371, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_2372 = tensor.from_elements %6731, %c4096_i64, %c1_i64 : tensor<3xi64> + %6733 = stablehlo.dynamic_reshape %6730, %from_elements_2372 : (tensor, tensor<3xi64>) -> tensor + %6734 = stablehlo.dynamic_iota %from_elements_2372, dim = 1 : (tensor<3xi64>) -> tensor + %6735 = stablehlo.concatenate %6733, %6734, dim = 2 : (tensor, tensor) -> tensor + %6736 = "stablehlo.scatter"(%6673, %6735, %6732) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %6737 = stablehlo.slice %6293 [7:8, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %6738 = stablehlo.reshape %6737 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %6739 = stablehlo.custom_call @byteir.non_zero(%6738) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_2373 = tensor.dim %6739, %c0 : tensor + %6740 = arith.index_cast %dim_2373 : index to i64 + %from_elements_2374 = tensor.from_elements %6740, %c1_i64 : tensor<2xi64> + %6741 = stablehlo.real_dynamic_slice %6739, %c_22, %from_elements_2374, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2375 = tensor.dim %6741, %c0 : tensor + %6742 = arith.index_cast %dim_2375 : index to i64 + %from_elements_2376 = tensor.from_elements %6742 : tensor<1xi64> + %6743 = stablehlo.dynamic_reshape %6741, %from_elements_2376 : (tensor, tensor<1xi64>) -> tensor + %from_elements_2377 = tensor.from_elements %6740, %c2_i64 : tensor<2xi64> + %6744 = stablehlo.real_dynamic_slice %6739, %c_24, %from_elements_2377, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2378 = tensor.dim %6744, %c0 : tensor + %6745 = arith.index_cast %dim_2378 : index to i64 + %from_elements_2379 = tensor.from_elements %6745 : tensor<1xi64> + %6746 = stablehlo.dynamic_reshape %6744, %from_elements_2379 : (tensor, tensor<1xi64>) -> tensor + %dim_2380 = tensor.dim %6746, %c0 : tensor + %6747 = arith.index_cast %dim_2380 : index to i64 + %from_elements_2381 = tensor.from_elements %6747, %c1_i64 : tensor<2xi64> + %6748 = stablehlo.dynamic_reshape %6746, %from_elements_2381 : (tensor, tensor<2xi64>) -> tensor + %dim_2382 = tensor.dim %6748, %c0 : tensor + %6749 = arith.index_cast %dim_2382 : index to i64 + %from_elements_2383 = tensor.from_elements %c1_i64, %6749, %c4096_i64 : tensor<3xi64> + %6750 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_2383, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2384 = tensor.dim %6750, %c1 : tensor<1x?x4096xi64> + %6751 = arith.index_cast %dim_2384 : index to i64 + %from_elements_2385 = tensor.from_elements %c1_i64, %6751, %c4096_i64, %c1_i64 : tensor<4xi64> + %6752 = 
stablehlo.dynamic_reshape %6750, %from_elements_2385 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %6753 = stablehlo.dynamic_broadcast_in_dim %6748, %from_elements_2383, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2386 = tensor.dim %6753, %c1 : tensor<1x?x4096xi64> + %6754 = arith.index_cast %dim_2386 : index to i64 + %from_elements_2387 = tensor.from_elements %c1_i64, %6754, %c4096_i64, %c1_i64 : tensor<4xi64> + %6755 = stablehlo.dynamic_reshape %6753, %from_elements_2387 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %6756 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_2383, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2388 = tensor.dim %6756, %c1 : tensor<1x?x4096xi64> + %6757 = arith.index_cast %dim_2388 : index to i64 + %from_elements_2389 = tensor.from_elements %c1_i64, %6757, %c4096_i64, %c1_i64 : tensor<4xi64> + %6758 = stablehlo.dynamic_reshape %6756, %from_elements_2389 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %6759 = stablehlo.concatenate %6752, %6755, %6758, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %6760 = "stablehlo.gather"(%6304, %6759) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %6761 = shape.shape_of %6760 : tensor<1x?x4096xf32> -> tensor<3xindex> + %6762 = shape.num_elements %6761 : tensor<3xindex> -> index + %6763 = stablehlo.compute_reshape_shape %6762, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %6764 = stablehlo.dynamic_reshape %6760, %6763 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %6765 = stablehlo.dot %6764, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %6766 = stablehlo.logistic %6765 : tensor + %6767 = shape.shape_of %6766 : tensor -> tensor<2xindex> + %6768 = shape.shape_of %6765 : tensor -> tensor<2xindex> + %6769 = shape.cstr_broadcastable %6767, %6768 : tensor<2xindex>, tensor<2xindex> + %6770 = shape.assuming %6769 -> (tensor) { + %19688 = shape.broadcast %6767, %6768 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %6766, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %6765, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %6771 = shape.shape_of %6770 : tensor -> tensor<2xindex> + %6772 = shape.cstr_broadcastable %6771, %6768 : tensor<2xindex>, tensor<2xindex> + %6773 = shape.assuming %6772 -> (tensor) { + %19688 = shape.broadcast %6771, %6768 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %6770, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %6765, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %6774 = stablehlo.dot %6773, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_2390 = tensor.dim %6746, %c0 : tensor + %6775 = arith.index_cast %dim_2390 : index to i64 + %from_elements_2391 = tensor.from_elements %6775, %c1_i64 : tensor<2xi64> + %6776 = stablehlo.dynamic_reshape %6746, %from_elements_2391 : (tensor, tensor<2xi64>) -> tensor + %dim_2392 = tensor.dim %6743, %c0 
: tensor + %6777 = arith.index_cast %dim_2392 : index to i64 + %from_elements_2393 = tensor.from_elements %6777, %c1_i64 : tensor<2xi64> + %6778 = stablehlo.dynamic_reshape %6743, %from_elements_2393 : (tensor, tensor<2xi64>) -> tensor + %6779 = stablehlo.concatenate %6776, %6778, dim = 1 : (tensor, tensor) -> tensor + %6780 = "stablehlo.gather"(%6333, %6779) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %6781 = shape.shape_of %6774 : tensor -> tensor<2xindex> + %6782 = shape.shape_of %6780 : tensor -> tensor<2xindex> + %6783 = shape.cstr_broadcastable %6781, %6782 : tensor<2xindex>, tensor<2xindex> + %6784 = shape.assuming %6783 -> (tensor) { + %19688 = shape.broadcast %6781, %6782 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %6774, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %6780, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %6785 = shape.shape_of %6784 : tensor -> tensor<2xindex> + %6786 = stablehlo.dynamic_broadcast_in_dim %6784, %6785, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %6787 = stablehlo.dynamic_broadcast_in_dim %213, %6785, dims = [] : (tensor, tensor<2xindex>) -> tensor + %6788 = stablehlo.multiply %6786, %6787 : tensor + %dim_2394 = tensor.dim %6748, %c0 : tensor + %6789 = arith.index_cast %dim_2394 : index to i64 + %dim_2395 = tensor.dim %6784, %c0 : tensor + %6790 = arith.index_cast %dim_2395 : index to i64 + %6791 = arith.maxsi %6789, %6790 : i64 + %6792 = arith.index_cast %6791 : i64 to index + %from_elements_2396 = tensor.from_elements %6792, %c4096 : tensor<2xindex> + %6793 = stablehlo.dynamic_broadcast_in_dim %6748, %from_elements_2396, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_2397 = tensor.dim %6793, %c0 : tensor + %6794 = arith.index_cast %dim_2397 : index to i64 + %from_elements_2398 = tensor.from_elements %6794, %c4096_i64 : tensor<2xi64> + %6795 = stablehlo.real_dynamic_slice %6788, %c_22, %from_elements_2398, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_2399 = tensor.from_elements %6794, %c4096_i64, %c1_i64 : tensor<3xi64> + %6796 = stablehlo.dynamic_reshape %6793, %from_elements_2399 : (tensor, tensor<3xi64>) -> tensor + %6797 = stablehlo.dynamic_iota %from_elements_2399, dim = 1 : (tensor<3xi64>) -> tensor + %6798 = stablehlo.concatenate %6796, %6797, dim = 2 : (tensor, tensor) -> tensor + %6799 = "stablehlo.scatter"(%6736, %6798, %6795) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %6800 = stablehlo.reshape %6799 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %6801 = stablehlo.add %6266, %6800 : tensor<3x1x4096xf32> + %6802 = stablehlo.broadcast_in_dim %6801, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %6803 = stablehlo.power %6802, %15 : tensor<3x1x4096xf32> + %6804 = stablehlo.reduce(%6803 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %6805 = stablehlo.reshape %6804 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %6806 = 
stablehlo.broadcast_in_dim %6805, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %6807 = stablehlo.divide %6806, %21 : tensor<3x1x1xf32> + %6808 = stablehlo.broadcast_in_dim %6807, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %6809 = stablehlo.add %6808, %25 : tensor<3x1x1xf32> + %6810 = stablehlo.rsqrt %6809 : tensor<3x1x1xf32> + %6811 = stablehlo.broadcast_in_dim %6810, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %6812 = stablehlo.multiply %6802, %6811 : tensor<3x1x4096xf32> + %6813 = stablehlo.broadcast_in_dim %6812, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %6814 = stablehlo.multiply %6813, %31 : tensor<3x1x4096xf32> + %6815 = stablehlo.reshape %6814 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %6816 = stablehlo.dot %6815, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %6817 = stablehlo.reshape %6816 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %6818 = stablehlo.dot %6815, %37 : (tensor<3x4096xf32>, tensor<4096x1024xf32>) -> tensor<3x1024xf32> + %6819 = stablehlo.reshape %6818 : (tensor<3x1024xf32>) -> tensor<3x1x1024xf32> + %6820 = stablehlo.reshape %6817 : (tensor<3x1x4096xf32>) -> tensor<3x1x32x128xf32> + %6821 = stablehlo.transpose %6820, dims = [0, 2, 1, 3] : (tensor<3x1x32x128xf32>) -> tensor<3x32x1x128xf32> + %6822 = stablehlo.reshape %6819 : (tensor<3x1x1024xf32>) -> tensor<3x1x8x128xf32> + %6823 = stablehlo.transpose %6822, dims = [0, 2, 1, 3] : (tensor<3x1x8x128xf32>) -> tensor<3x8x1x128xf32> + %6824 = stablehlo.slice %arg22 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %6825 = stablehlo.slice %arg23 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %6826 = "stablehlo.gather"(%6824, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %6827 = stablehlo.reshape %6826 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %6828 = "stablehlo.gather"(%6825, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %6829 = stablehlo.reshape %6828 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %6830 = stablehlo.broadcast_in_dim %6821, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %6831 = stablehlo.broadcast_in_dim %6827, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %6832 = stablehlo.multiply %6830, %6831 : tensor<3x32x1x128xf32> + %6833 = stablehlo.slice %6821 [0:3, 0:32, 0:1, 0:64] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %6834 = stablehlo.slice %6821 [0:3, 0:32, 0:1, 64:128] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %6835 = stablehlo.negate %6834 : tensor<3x32x1x64xf32> + %6836 = stablehlo.concatenate %6835, %6833, dim = 3 : (tensor<3x32x1x64xf32>, tensor<3x32x1x64xf32>) -> tensor<3x32x1x128xf32> + %6837 = stablehlo.broadcast_in_dim %6836, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %6838 = stablehlo.broadcast_in_dim %6829, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %6839 = stablehlo.multiply %6837, %6838 : tensor<3x32x1x128xf32> + %6840 = stablehlo.add %6832, %6839 : tensor<3x32x1x128xf32> + %6841 = stablehlo.broadcast_in_dim %6823, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %6842 = stablehlo.broadcast_in_dim %6827, dims = [0, 1, 2, 3] : 
(tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %6843 = stablehlo.multiply %6841, %6842 : tensor<3x8x1x128xf32> + %6844 = stablehlo.slice %6823 [0:3, 0:8, 0:1, 0:64] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %6845 = stablehlo.slice %6823 [0:3, 0:8, 0:1, 64:128] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %6846 = stablehlo.negate %6845 : tensor<3x8x1x64xf32> + %6847 = stablehlo.concatenate %6846, %6844, dim = 3 : (tensor<3x8x1x64xf32>, tensor<3x8x1x64xf32>) -> tensor<3x8x1x128xf32> + %6848 = stablehlo.broadcast_in_dim %6847, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %6849 = stablehlo.broadcast_in_dim %6829, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %6850 = stablehlo.multiply %6848, %6849 : tensor<3x8x1x128xf32> + %6851 = stablehlo.add %6843, %6850 : tensor<3x8x1x128xf32> + %6852 = stablehlo.concatenate %arg87, %6851, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %6853 = stablehlo.concatenate %arg88, %6823, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %6854 = stablehlo.reshape %6852 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %6855 = stablehlo.broadcast_in_dim %6854, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %6856 = stablehlo.reshape %6855 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %6857 = stablehlo.reshape %6853 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %6858 = stablehlo.broadcast_in_dim %6857, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %6859 = stablehlo.reshape %6858 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %6860 = stablehlo.transpose %6856, dims = [0, 1, 3, 2] : (tensor<3x32x8x128xf32>) -> tensor<3x32x128x8xf32> + %6861 = stablehlo.reshape %6840 : (tensor<3x32x1x128xf32>) -> tensor<96x1x128xf32> + %6862 = stablehlo.reshape %6860 : (tensor<3x32x128x8xf32>) -> tensor<96x128x8xf32> + %6863 = stablehlo.broadcast_in_dim %6862, dims = [0, 1, 2] : (tensor<96x128x8xf32>) -> tensor<96x128x8xf32> + %6864 = stablehlo.dot_general %6861, %6863, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x128xf32>, tensor<96x128x8xf32>) -> tensor<96x1x8xf32> + %6865 = stablehlo.reshape %6864 : (tensor<96x1x8xf32>) -> tensor<3x32x1x8xf32> + %6866 = stablehlo.broadcast_in_dim %6865, dims = [0, 1, 2, 3] : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %6867 = stablehlo.divide %6866, %89 : tensor<3x32x1x8xf32> + %6868 = stablehlo.custom_call @byteir.softmax(%6867) {byteir_attrs = {axis = 3 : i64}} : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %6869 = stablehlo.reshape %6868 : (tensor<3x32x1x8xf32>) -> tensor<96x1x8xf32> + %6870 = stablehlo.reshape %6859 : (tensor<3x32x8x128xf32>) -> tensor<96x8x128xf32> + %6871 = stablehlo.broadcast_in_dim %6870, dims = [0, 1, 2] : (tensor<96x8x128xf32>) -> tensor<96x8x128xf32> + %6872 = stablehlo.dot_general %6869, %6871, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x8xf32>, tensor<96x8x128xf32>) -> tensor<96x1x128xf32> + %6873 = stablehlo.reshape %6872 : (tensor<96x1x128xf32>) -> tensor<3x32x1x128xf32> + %6874 = stablehlo.transpose %6873, dims = [0, 2, 1, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x1x32x128xf32> + %6875 = stablehlo.reshape %6874 : (tensor<3x1x32x128xf32>) -> tensor<3x1x4096xf32> + %6876 = stablehlo.reshape %6875 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %6877 = stablehlo.dot %6876, %33 : (tensor<3x4096xf32>, 
tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %6878 = stablehlo.reshape %6877 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %6879 = stablehlo.add %6801, %6878 : tensor<3x1x4096xf32> + %6880 = stablehlo.broadcast_in_dim %6879, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %6881 = stablehlo.power %6880, %15 : tensor<3x1x4096xf32> + %6882 = stablehlo.reduce(%6881 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %6883 = stablehlo.reshape %6882 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %6884 = stablehlo.broadcast_in_dim %6883, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %6885 = stablehlo.divide %6884, %21 : tensor<3x1x1xf32> + %6886 = stablehlo.broadcast_in_dim %6885, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %6887 = stablehlo.add %6886, %25 : tensor<3x1x1xf32> + %6888 = stablehlo.rsqrt %6887 : tensor<3x1x1xf32> + %6889 = stablehlo.broadcast_in_dim %6888, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %6890 = stablehlo.multiply %6880, %6889 : tensor<3x1x4096xf32> + %6891 = stablehlo.broadcast_in_dim %6890, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %6892 = stablehlo.multiply %6891, %31 : tensor<3x1x4096xf32> + %6893 = stablehlo.reshape %6892 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %6894 = stablehlo.dot %6893, %117 : (tensor<3x4096xf32>, tensor<4096x8xf32>) -> tensor<3x8xf32> + %6895 = stablehlo.custom_call @byteir.softmax(%6894) {byteir_attrs = {axis = 1 : i64}} : (tensor<3x8xf32>) -> tensor<3x8xf32> + %6896:2 = stablehlo.custom_call @byteir.top_k(%6895) {byteir_attrs = {axis = [1], k = 2 : i64, sorted = true}} : (tensor<3x8xf32>) -> (tensor<3x2xf32>, tensor<3x2xi64>) + %6897 = stablehlo.reduce(%6896#0 init: %cst_20) applies stablehlo.add across dimensions = [1] : (tensor<3x2xf32>, tensor) -> tensor<3xf32> + %6898 = stablehlo.reshape %6897 : (tensor<3xf32>) -> tensor<3x1xf32> + %6899 = stablehlo.broadcast_in_dim %6896#0, dims = [0, 1] : (tensor<3x2xf32>) -> tensor<3x2xf32> + %6900 = stablehlo.broadcast_in_dim %6898, dims = [0, 1] : (tensor<3x1xf32>) -> tensor<3x2xf32> + %6901 = stablehlo.divide %6899, %6900 : tensor<3x2xf32> + %6902 = stablehlo.reshape %6896#1 : (tensor<3x2xi64>) -> tensor<3x2x1xi64> + %6903 = stablehlo.broadcast_in_dim %6902, dims = [0, 1, 2] : (tensor<3x2x1xi64>) -> tensor<3x2x8xi64> + %6904 = stablehlo.compare EQ, %6903, %137, SIGNED : (tensor<3x2x8xi64>, tensor<3x2x8xi64>) -> tensor<3x2x8xi1> + %6905 = stablehlo.convert %6904 : (tensor<3x2x8xi1>) -> tensor<3x2x8xi64> + %6906 = stablehlo.transpose %6905, dims = [2, 1, 0] : (tensor<3x2x8xi64>) -> tensor<8x2x3xi64> + %6907 = stablehlo.slice %6906 [0:1, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %6908 = stablehlo.reshape %6907 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %6909 = stablehlo.custom_call @byteir.non_zero(%6908) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_2400 = tensor.dim %6909, %c0 : tensor + %6910 = arith.index_cast %dim_2400 : index to i64 + %from_elements_2401 = tensor.from_elements %6910, %c1_i64 : tensor<2xi64> + %6911 = stablehlo.real_dynamic_slice %6909, %c_22, %from_elements_2401, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2402 = tensor.dim %6911, %c0 : tensor + %6912 = arith.index_cast %dim_2402 : index to i64 + %from_elements_2403 = tensor.from_elements %6912 : tensor<1xi64> + %6913 = stablehlo.dynamic_reshape %6911, %from_elements_2403 : (tensor, tensor<1xi64>) -> 
tensor + %from_elements_2404 = tensor.from_elements %6910, %c2_i64 : tensor<2xi64> + %6914 = stablehlo.real_dynamic_slice %6909, %c_24, %from_elements_2404, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2405 = tensor.dim %6914, %c0 : tensor + %6915 = arith.index_cast %dim_2405 : index to i64 + %from_elements_2406 = tensor.from_elements %6915 : tensor<1xi64> + %6916 = stablehlo.dynamic_reshape %6914, %from_elements_2406 : (tensor, tensor<1xi64>) -> tensor + %6917 = stablehlo.reshape %6893 : (tensor<3x4096xf32>) -> tensor<1x3x4096xf32> + %dim_2407 = tensor.dim %6916, %c0 : tensor + %6918 = arith.index_cast %dim_2407 : index to i64 + %from_elements_2408 = tensor.from_elements %6918, %c1_i64 : tensor<2xi64> + %6919 = stablehlo.dynamic_reshape %6916, %from_elements_2408 : (tensor, tensor<2xi64>) -> tensor + %dim_2409 = tensor.dim %6919, %c0 : tensor + %6920 = arith.index_cast %dim_2409 : index to i64 + %from_elements_2410 = tensor.from_elements %c1_i64, %6920, %c4096_i64 : tensor<3xi64> + %6921 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_2410, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2411 = tensor.dim %6921, %c1 : tensor<1x?x4096xi64> + %6922 = arith.index_cast %dim_2411 : index to i64 + %from_elements_2412 = tensor.from_elements %c1_i64, %6922, %c4096_i64, %c1_i64 : tensor<4xi64> + %6923 = stablehlo.dynamic_reshape %6921, %from_elements_2412 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %6924 = stablehlo.dynamic_broadcast_in_dim %6919, %from_elements_2410, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2413 = tensor.dim %6924, %c1 : tensor<1x?x4096xi64> + %6925 = arith.index_cast %dim_2413 : index to i64 + %from_elements_2414 = tensor.from_elements %c1_i64, %6925, %c4096_i64, %c1_i64 : tensor<4xi64> + %6926 = stablehlo.dynamic_reshape %6924, %from_elements_2414 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %6927 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_2410, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2415 = tensor.dim %6927, %c1 : tensor<1x?x4096xi64> + %6928 = arith.index_cast %dim_2415 : index to i64 + %from_elements_2416 = tensor.from_elements %c1_i64, %6928, %c4096_i64, %c1_i64 : tensor<4xi64> + %6929 = stablehlo.dynamic_reshape %6927, %from_elements_2416 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %6930 = stablehlo.concatenate %6923, %6926, %6929, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %6931 = "stablehlo.gather"(%6917, %6930) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %6932 = shape.shape_of %6931 : tensor<1x?x4096xf32> -> tensor<3xindex> + %6933 = shape.num_elements %6932 : tensor<3xindex> -> index + %6934 = stablehlo.compute_reshape_shape %6933, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %6935 = stablehlo.dynamic_reshape %6931, %6934 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %6936 = stablehlo.dot %6935, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %6937 = stablehlo.logistic %6936 : tensor + %6938 = shape.shape_of %6937 : tensor -> tensor<2xindex> + %6939 = shape.shape_of %6936 : tensor -> tensor<2xindex> + %6940 = shape.cstr_broadcastable %6938, %6939 : tensor<2xindex>, tensor<2xindex> + %6941 = shape.assuming %6940 -> (tensor) { + %19688 = 
shape.broadcast %6938, %6939 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %6937, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %6936, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %6942 = shape.shape_of %6941 : tensor -> tensor<2xindex> + %6943 = shape.cstr_broadcastable %6942, %6939 : tensor<2xindex>, tensor<2xindex> + %6944 = shape.assuming %6943 -> (tensor) { + %19688 = shape.broadcast %6942, %6939 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %6941, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %6936, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %6945 = stablehlo.dot %6944, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %6946 = stablehlo.reshape %6901 : (tensor<3x2xf32>) -> tensor<3x2x1xf32> + %dim_2417 = tensor.dim %6916, %c0 : tensor + %6947 = arith.index_cast %dim_2417 : index to i64 + %from_elements_2418 = tensor.from_elements %6947, %c1_i64 : tensor<2xi64> + %6948 = stablehlo.dynamic_reshape %6916, %from_elements_2418 : (tensor, tensor<2xi64>) -> tensor + %dim_2419 = tensor.dim %6913, %c0 : tensor + %6949 = arith.index_cast %dim_2419 : index to i64 + %from_elements_2420 = tensor.from_elements %6949, %c1_i64 : tensor<2xi64> + %6950 = stablehlo.dynamic_reshape %6913, %from_elements_2420 : (tensor, tensor<2xi64>) -> tensor + %6951 = stablehlo.concatenate %6948, %6950, dim = 1 : (tensor, tensor) -> tensor + %6952 = "stablehlo.gather"(%6946, %6951) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %6953 = shape.shape_of %6945 : tensor -> tensor<2xindex> + %6954 = shape.shape_of %6952 : tensor -> tensor<2xindex> + %6955 = shape.cstr_broadcastable %6953, %6954 : tensor<2xindex>, tensor<2xindex> + %6956 = shape.assuming %6955 -> (tensor) { + %19688 = shape.broadcast %6953, %6954 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %6945, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %6952, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %6957 = shape.shape_of %6956 : tensor -> tensor<2xindex> + %6958 = stablehlo.dynamic_broadcast_in_dim %6956, %6957, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %6959 = stablehlo.dynamic_broadcast_in_dim %213, %6957, dims = [] : (tensor, tensor<2xindex>) -> tensor + %6960 = stablehlo.multiply %6958, %6959 : tensor + %dim_2421 = tensor.dim %6919, %c0 : tensor + %6961 = arith.index_cast %dim_2421 : index to i64 + %dim_2422 = tensor.dim %6956, %c0 : tensor + %6962 = arith.index_cast %dim_2422 : index to i64 + %6963 = arith.maxsi %6961, %6962 : i64 + %6964 = arith.index_cast %6963 : i64 to index + %from_elements_2423 = tensor.from_elements %6964, %c4096 : tensor<2xindex> + %6965 = stablehlo.dynamic_broadcast_in_dim %6919, %from_elements_2423, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_2424 = tensor.dim %6965, %c0 : tensor + %6966 = arith.index_cast %dim_2424 : index to i64 + 
%from_elements_2425 = tensor.from_elements %6966, %c4096_i64 : tensor<2xi64> + %6967 = stablehlo.real_dynamic_slice %6960, %c_22, %from_elements_2425, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_2426 = tensor.from_elements %6966, %c4096_i64, %c1_i64 : tensor<3xi64> + %6968 = stablehlo.dynamic_reshape %6965, %from_elements_2426 : (tensor, tensor<3xi64>) -> tensor + %6969 = stablehlo.dynamic_iota %from_elements_2426, dim = 1 : (tensor<3xi64>) -> tensor + %6970 = stablehlo.concatenate %6968, %6969, dim = 2 : (tensor, tensor) -> tensor + %6971 = "stablehlo.scatter"(%cst_2, %6970, %6967) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %6972 = stablehlo.slice %6906 [1:2, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %6973 = stablehlo.reshape %6972 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %6974 = stablehlo.custom_call @byteir.non_zero(%6973) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_2427 = tensor.dim %6974, %c0 : tensor + %6975 = arith.index_cast %dim_2427 : index to i64 + %from_elements_2428 = tensor.from_elements %6975, %c1_i64 : tensor<2xi64> + %6976 = stablehlo.real_dynamic_slice %6974, %c_22, %from_elements_2428, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2429 = tensor.dim %6976, %c0 : tensor + %6977 = arith.index_cast %dim_2429 : index to i64 + %from_elements_2430 = tensor.from_elements %6977 : tensor<1xi64> + %6978 = stablehlo.dynamic_reshape %6976, %from_elements_2430 : (tensor, tensor<1xi64>) -> tensor + %from_elements_2431 = tensor.from_elements %6975, %c2_i64 : tensor<2xi64> + %6979 = stablehlo.real_dynamic_slice %6974, %c_24, %from_elements_2431, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2432 = tensor.dim %6979, %c0 : tensor + %6980 = arith.index_cast %dim_2432 : index to i64 + %from_elements_2433 = tensor.from_elements %6980 : tensor<1xi64> + %6981 = stablehlo.dynamic_reshape %6979, %from_elements_2433 : (tensor, tensor<1xi64>) -> tensor + %dim_2434 = tensor.dim %6981, %c0 : tensor + %6982 = arith.index_cast %dim_2434 : index to i64 + %from_elements_2435 = tensor.from_elements %6982, %c1_i64 : tensor<2xi64> + %6983 = stablehlo.dynamic_reshape %6981, %from_elements_2435 : (tensor, tensor<2xi64>) -> tensor + %dim_2436 = tensor.dim %6983, %c0 : tensor + %6984 = arith.index_cast %dim_2436 : index to i64 + %from_elements_2437 = tensor.from_elements %c1_i64, %6984, %c4096_i64 : tensor<3xi64> + %6985 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_2437, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2438 = tensor.dim %6985, %c1 : tensor<1x?x4096xi64> + %6986 = arith.index_cast %dim_2438 : index to i64 + %from_elements_2439 = tensor.from_elements %c1_i64, %6986, %c4096_i64, %c1_i64 : tensor<4xi64> + %6987 = stablehlo.dynamic_reshape %6985, %from_elements_2439 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %6988 = stablehlo.dynamic_broadcast_in_dim %6983, %from_elements_2437, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2440 = tensor.dim %6988, %c1 : tensor<1x?x4096xi64> + %6989 = arith.index_cast %dim_2440 : index to i64 + %from_elements_2441 = tensor.from_elements %c1_i64, %6989, %c4096_i64, %c1_i64 : 
tensor<4xi64> + %6990 = stablehlo.dynamic_reshape %6988, %from_elements_2441 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %6991 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_2437, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2442 = tensor.dim %6991, %c1 : tensor<1x?x4096xi64> + %6992 = arith.index_cast %dim_2442 : index to i64 + %from_elements_2443 = tensor.from_elements %c1_i64, %6992, %c4096_i64, %c1_i64 : tensor<4xi64> + %6993 = stablehlo.dynamic_reshape %6991, %from_elements_2443 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %6994 = stablehlo.concatenate %6987, %6990, %6993, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %6995 = "stablehlo.gather"(%6917, %6994) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %6996 = shape.shape_of %6995 : tensor<1x?x4096xf32> -> tensor<3xindex> + %6997 = shape.num_elements %6996 : tensor<3xindex> -> index + %6998 = stablehlo.compute_reshape_shape %6997, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %6999 = stablehlo.dynamic_reshape %6995, %6998 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %7000 = stablehlo.dot %6999, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %7001 = stablehlo.logistic %7000 : tensor + %7002 = shape.shape_of %7001 : tensor -> tensor<2xindex> + %7003 = shape.shape_of %7000 : tensor -> tensor<2xindex> + %7004 = shape.cstr_broadcastable %7002, %7003 : tensor<2xindex>, tensor<2xindex> + %7005 = shape.assuming %7004 -> (tensor) { + %19688 = shape.broadcast %7002, %7003 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %7001, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %7000, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %7006 = shape.shape_of %7005 : tensor -> tensor<2xindex> + %7007 = shape.cstr_broadcastable %7006, %7003 : tensor<2xindex>, tensor<2xindex> + %7008 = shape.assuming %7007 -> (tensor) { + %19688 = shape.broadcast %7006, %7003 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %7005, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %7000, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %7009 = stablehlo.dot %7008, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_2444 = tensor.dim %6981, %c0 : tensor + %7010 = arith.index_cast %dim_2444 : index to i64 + %from_elements_2445 = tensor.from_elements %7010, %c1_i64 : tensor<2xi64> + %7011 = stablehlo.dynamic_reshape %6981, %from_elements_2445 : (tensor, tensor<2xi64>) -> tensor + %dim_2446 = tensor.dim %6978, %c0 : tensor + %7012 = arith.index_cast %dim_2446 : index to i64 + %from_elements_2447 = tensor.from_elements %7012, %c1_i64 : tensor<2xi64> + %7013 = stablehlo.dynamic_reshape %6978, %from_elements_2447 : (tensor, tensor<2xi64>) -> tensor + %7014 = stablehlo.concatenate %7011, %7013, dim = 1 : (tensor, tensor) -> tensor + %7015 = "stablehlo.gather"(%6946, %7014) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = 
array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %7016 = shape.shape_of %7009 : tensor -> tensor<2xindex> + %7017 = shape.shape_of %7015 : tensor -> tensor<2xindex> + %7018 = shape.cstr_broadcastable %7016, %7017 : tensor<2xindex>, tensor<2xindex> + %7019 = shape.assuming %7018 -> (tensor) { + %19688 = shape.broadcast %7016, %7017 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %7009, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %7015, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %7020 = shape.shape_of %7019 : tensor -> tensor<2xindex> + %7021 = stablehlo.dynamic_broadcast_in_dim %7019, %7020, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %7022 = stablehlo.dynamic_broadcast_in_dim %213, %7020, dims = [] : (tensor, tensor<2xindex>) -> tensor + %7023 = stablehlo.multiply %7021, %7022 : tensor + %dim_2448 = tensor.dim %6983, %c0 : tensor + %7024 = arith.index_cast %dim_2448 : index to i64 + %dim_2449 = tensor.dim %7019, %c0 : tensor + %7025 = arith.index_cast %dim_2449 : index to i64 + %7026 = arith.maxsi %7024, %7025 : i64 + %7027 = arith.index_cast %7026 : i64 to index + %from_elements_2450 = tensor.from_elements %7027, %c4096 : tensor<2xindex> + %7028 = stablehlo.dynamic_broadcast_in_dim %6983, %from_elements_2450, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_2451 = tensor.dim %7028, %c0 : tensor + %7029 = arith.index_cast %dim_2451 : index to i64 + %from_elements_2452 = tensor.from_elements %7029, %c4096_i64 : tensor<2xi64> + %7030 = stablehlo.real_dynamic_slice %7023, %c_22, %from_elements_2452, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_2453 = tensor.from_elements %7029, %c4096_i64, %c1_i64 : tensor<3xi64> + %7031 = stablehlo.dynamic_reshape %7028, %from_elements_2453 : (tensor, tensor<3xi64>) -> tensor + %7032 = stablehlo.dynamic_iota %from_elements_2453, dim = 1 : (tensor<3xi64>) -> tensor + %7033 = stablehlo.concatenate %7031, %7032, dim = 2 : (tensor, tensor) -> tensor + %7034 = "stablehlo.scatter"(%6971, %7033, %7030) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %7035 = stablehlo.slice %6906 [2:3, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %7036 = stablehlo.reshape %7035 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %7037 = stablehlo.custom_call @byteir.non_zero(%7036) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_2454 = tensor.dim %7037, %c0 : tensor + %7038 = arith.index_cast %dim_2454 : index to i64 + %from_elements_2455 = tensor.from_elements %7038, %c1_i64 : tensor<2xi64> + %7039 = stablehlo.real_dynamic_slice %7037, %c_22, %from_elements_2455, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2456 = tensor.dim %7039, %c0 : tensor + %7040 = arith.index_cast %dim_2456 : index to i64 + %from_elements_2457 = tensor.from_elements %7040 : tensor<1xi64> + %7041 = stablehlo.dynamic_reshape %7039, %from_elements_2457 : (tensor, tensor<1xi64>) -> tensor + %from_elements_2458 = tensor.from_elements %7038, %c2_i64 : tensor<2xi64> + %7042 = stablehlo.real_dynamic_slice %7037, %c_24, 
%from_elements_2458, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2459 = tensor.dim %7042, %c0 : tensor + %7043 = arith.index_cast %dim_2459 : index to i64 + %from_elements_2460 = tensor.from_elements %7043 : tensor<1xi64> + %7044 = stablehlo.dynamic_reshape %7042, %from_elements_2460 : (tensor, tensor<1xi64>) -> tensor + %dim_2461 = tensor.dim %7044, %c0 : tensor + %7045 = arith.index_cast %dim_2461 : index to i64 + %from_elements_2462 = tensor.from_elements %7045, %c1_i64 : tensor<2xi64> + %7046 = stablehlo.dynamic_reshape %7044, %from_elements_2462 : (tensor, tensor<2xi64>) -> tensor + %dim_2463 = tensor.dim %7046, %c0 : tensor + %7047 = arith.index_cast %dim_2463 : index to i64 + %from_elements_2464 = tensor.from_elements %c1_i64, %7047, %c4096_i64 : tensor<3xi64> + %7048 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_2464, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2465 = tensor.dim %7048, %c1 : tensor<1x?x4096xi64> + %7049 = arith.index_cast %dim_2465 : index to i64 + %from_elements_2466 = tensor.from_elements %c1_i64, %7049, %c4096_i64, %c1_i64 : tensor<4xi64> + %7050 = stablehlo.dynamic_reshape %7048, %from_elements_2466 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %7051 = stablehlo.dynamic_broadcast_in_dim %7046, %from_elements_2464, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2467 = tensor.dim %7051, %c1 : tensor<1x?x4096xi64> + %7052 = arith.index_cast %dim_2467 : index to i64 + %from_elements_2468 = tensor.from_elements %c1_i64, %7052, %c4096_i64, %c1_i64 : tensor<4xi64> + %7053 = stablehlo.dynamic_reshape %7051, %from_elements_2468 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %7054 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_2464, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2469 = tensor.dim %7054, %c1 : tensor<1x?x4096xi64> + %7055 = arith.index_cast %dim_2469 : index to i64 + %from_elements_2470 = tensor.from_elements %c1_i64, %7055, %c4096_i64, %c1_i64 : tensor<4xi64> + %7056 = stablehlo.dynamic_reshape %7054, %from_elements_2470 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %7057 = stablehlo.concatenate %7050, %7053, %7056, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %7058 = "stablehlo.gather"(%6917, %7057) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %7059 = shape.shape_of %7058 : tensor<1x?x4096xf32> -> tensor<3xindex> + %7060 = shape.num_elements %7059 : tensor<3xindex> -> index + %7061 = stablehlo.compute_reshape_shape %7060, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %7062 = stablehlo.dynamic_reshape %7058, %7061 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %7063 = stablehlo.dot %7062, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %7064 = stablehlo.logistic %7063 : tensor + %7065 = shape.shape_of %7064 : tensor -> tensor<2xindex> + %7066 = shape.shape_of %7063 : tensor -> tensor<2xindex> + %7067 = shape.cstr_broadcastable %7065, %7066 : tensor<2xindex>, tensor<2xindex> + %7068 = shape.assuming %7067 -> (tensor) { + %19688 = shape.broadcast %7065, %7066 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %7064, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = 
stablehlo.dynamic_broadcast_in_dim %7063, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %7069 = shape.shape_of %7068 : tensor -> tensor<2xindex> + %7070 = shape.cstr_broadcastable %7069, %7066 : tensor<2xindex>, tensor<2xindex> + %7071 = shape.assuming %7070 -> (tensor) { + %19688 = shape.broadcast %7069, %7066 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %7068, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %7063, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %7072 = stablehlo.dot %7071, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_2471 = tensor.dim %7044, %c0 : tensor + %7073 = arith.index_cast %dim_2471 : index to i64 + %from_elements_2472 = tensor.from_elements %7073, %c1_i64 : tensor<2xi64> + %7074 = stablehlo.dynamic_reshape %7044, %from_elements_2472 : (tensor, tensor<2xi64>) -> tensor + %dim_2473 = tensor.dim %7041, %c0 : tensor + %7075 = arith.index_cast %dim_2473 : index to i64 + %from_elements_2474 = tensor.from_elements %7075, %c1_i64 : tensor<2xi64> + %7076 = stablehlo.dynamic_reshape %7041, %from_elements_2474 : (tensor, tensor<2xi64>) -> tensor + %7077 = stablehlo.concatenate %7074, %7076, dim = 1 : (tensor, tensor) -> tensor + %7078 = "stablehlo.gather"(%6946, %7077) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %7079 = shape.shape_of %7072 : tensor -> tensor<2xindex> + %7080 = shape.shape_of %7078 : tensor -> tensor<2xindex> + %7081 = shape.cstr_broadcastable %7079, %7080 : tensor<2xindex>, tensor<2xindex> + %7082 = shape.assuming %7081 -> (tensor) { + %19688 = shape.broadcast %7079, %7080 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %7072, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %7078, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %7083 = shape.shape_of %7082 : tensor -> tensor<2xindex> + %7084 = stablehlo.dynamic_broadcast_in_dim %7082, %7083, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %7085 = stablehlo.dynamic_broadcast_in_dim %213, %7083, dims = [] : (tensor, tensor<2xindex>) -> tensor + %7086 = stablehlo.multiply %7084, %7085 : tensor + %dim_2475 = tensor.dim %7046, %c0 : tensor + %7087 = arith.index_cast %dim_2475 : index to i64 + %dim_2476 = tensor.dim %7082, %c0 : tensor + %7088 = arith.index_cast %dim_2476 : index to i64 + %7089 = arith.maxsi %7087, %7088 : i64 + %7090 = arith.index_cast %7089 : i64 to index + %from_elements_2477 = tensor.from_elements %7090, %c4096 : tensor<2xindex> + %7091 = stablehlo.dynamic_broadcast_in_dim %7046, %from_elements_2477, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_2478 = tensor.dim %7091, %c0 : tensor + %7092 = arith.index_cast %dim_2478 : index to i64 + %from_elements_2479 = tensor.from_elements %7092, %c4096_i64 : tensor<2xi64> + %7093 = stablehlo.real_dynamic_slice %7086, %c_22, %from_elements_2479, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_2480 = tensor.from_elements %7092, %c4096_i64, %c1_i64 
: tensor<3xi64> + %7094 = stablehlo.dynamic_reshape %7091, %from_elements_2480 : (tensor, tensor<3xi64>) -> tensor + %7095 = stablehlo.dynamic_iota %from_elements_2480, dim = 1 : (tensor<3xi64>) -> tensor + %7096 = stablehlo.concatenate %7094, %7095, dim = 2 : (tensor, tensor) -> tensor + %7097 = "stablehlo.scatter"(%7034, %7096, %7093) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %7098 = stablehlo.slice %6906 [3:4, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %7099 = stablehlo.reshape %7098 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %7100 = stablehlo.custom_call @byteir.non_zero(%7099) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_2481 = tensor.dim %7100, %c0 : tensor + %7101 = arith.index_cast %dim_2481 : index to i64 + %from_elements_2482 = tensor.from_elements %7101, %c1_i64 : tensor<2xi64> + %7102 = stablehlo.real_dynamic_slice %7100, %c_22, %from_elements_2482, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2483 = tensor.dim %7102, %c0 : tensor + %7103 = arith.index_cast %dim_2483 : index to i64 + %from_elements_2484 = tensor.from_elements %7103 : tensor<1xi64> + %7104 = stablehlo.dynamic_reshape %7102, %from_elements_2484 : (tensor, tensor<1xi64>) -> tensor + %from_elements_2485 = tensor.from_elements %7101, %c2_i64 : tensor<2xi64> + %7105 = stablehlo.real_dynamic_slice %7100, %c_24, %from_elements_2485, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2486 = tensor.dim %7105, %c0 : tensor + %7106 = arith.index_cast %dim_2486 : index to i64 + %from_elements_2487 = tensor.from_elements %7106 : tensor<1xi64> + %7107 = stablehlo.dynamic_reshape %7105, %from_elements_2487 : (tensor, tensor<1xi64>) -> tensor + %dim_2488 = tensor.dim %7107, %c0 : tensor + %7108 = arith.index_cast %dim_2488 : index to i64 + %from_elements_2489 = tensor.from_elements %7108, %c1_i64 : tensor<2xi64> + %7109 = stablehlo.dynamic_reshape %7107, %from_elements_2489 : (tensor, tensor<2xi64>) -> tensor + %dim_2490 = tensor.dim %7109, %c0 : tensor + %7110 = arith.index_cast %dim_2490 : index to i64 + %from_elements_2491 = tensor.from_elements %c1_i64, %7110, %c4096_i64 : tensor<3xi64> + %7111 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_2491, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2492 = tensor.dim %7111, %c1 : tensor<1x?x4096xi64> + %7112 = arith.index_cast %dim_2492 : index to i64 + %from_elements_2493 = tensor.from_elements %c1_i64, %7112, %c4096_i64, %c1_i64 : tensor<4xi64> + %7113 = stablehlo.dynamic_reshape %7111, %from_elements_2493 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %7114 = stablehlo.dynamic_broadcast_in_dim %7109, %from_elements_2491, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2494 = tensor.dim %7114, %c1 : tensor<1x?x4096xi64> + %7115 = arith.index_cast %dim_2494 : index to i64 + %from_elements_2495 = tensor.from_elements %c1_i64, %7115, %c4096_i64, %c1_i64 : tensor<4xi64> + %7116 = stablehlo.dynamic_reshape %7114, %from_elements_2495 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %7117 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_2491, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2496 = 
tensor.dim %7117, %c1 : tensor<1x?x4096xi64> + %7118 = arith.index_cast %dim_2496 : index to i64 + %from_elements_2497 = tensor.from_elements %c1_i64, %7118, %c4096_i64, %c1_i64 : tensor<4xi64> + %7119 = stablehlo.dynamic_reshape %7117, %from_elements_2497 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %7120 = stablehlo.concatenate %7113, %7116, %7119, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %7121 = "stablehlo.gather"(%6917, %7120) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %7122 = shape.shape_of %7121 : tensor<1x?x4096xf32> -> tensor<3xindex> + %7123 = shape.num_elements %7122 : tensor<3xindex> -> index + %7124 = stablehlo.compute_reshape_shape %7123, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %7125 = stablehlo.dynamic_reshape %7121, %7124 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %7126 = stablehlo.dot %7125, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %7127 = stablehlo.logistic %7126 : tensor + %7128 = shape.shape_of %7127 : tensor -> tensor<2xindex> + %7129 = shape.shape_of %7126 : tensor -> tensor<2xindex> + %7130 = shape.cstr_broadcastable %7128, %7129 : tensor<2xindex>, tensor<2xindex> + %7131 = shape.assuming %7130 -> (tensor) { + %19688 = shape.broadcast %7128, %7129 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %7127, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %7126, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %7132 = shape.shape_of %7131 : tensor -> tensor<2xindex> + %7133 = shape.cstr_broadcastable %7132, %7129 : tensor<2xindex>, tensor<2xindex> + %7134 = shape.assuming %7133 -> (tensor) { + %19688 = shape.broadcast %7132, %7129 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %7131, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %7126, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %7135 = stablehlo.dot %7134, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_2498 = tensor.dim %7107, %c0 : tensor + %7136 = arith.index_cast %dim_2498 : index to i64 + %from_elements_2499 = tensor.from_elements %7136, %c1_i64 : tensor<2xi64> + %7137 = stablehlo.dynamic_reshape %7107, %from_elements_2499 : (tensor, tensor<2xi64>) -> tensor + %dim_2500 = tensor.dim %7104, %c0 : tensor + %7138 = arith.index_cast %dim_2500 : index to i64 + %from_elements_2501 = tensor.from_elements %7138, %c1_i64 : tensor<2xi64> + %7139 = stablehlo.dynamic_reshape %7104, %from_elements_2501 : (tensor, tensor<2xi64>) -> tensor + %7140 = stablehlo.concatenate %7137, %7139, dim = 1 : (tensor, tensor) -> tensor + %7141 = "stablehlo.gather"(%6946, %7140) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %7142 = shape.shape_of %7135 : tensor -> tensor<2xindex> + %7143 = shape.shape_of %7141 : tensor -> tensor<2xindex> + %7144 = shape.cstr_broadcastable %7142, %7143 : tensor<2xindex>, tensor<2xindex> + %7145 = shape.assuming %7144 -> (tensor) { + %19688 = 
shape.broadcast %7142, %7143 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %7135, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %7141, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %7146 = shape.shape_of %7145 : tensor -> tensor<2xindex> + %7147 = stablehlo.dynamic_broadcast_in_dim %7145, %7146, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %7148 = stablehlo.dynamic_broadcast_in_dim %213, %7146, dims = [] : (tensor, tensor<2xindex>) -> tensor + %7149 = stablehlo.multiply %7147, %7148 : tensor + %dim_2502 = tensor.dim %7109, %c0 : tensor + %7150 = arith.index_cast %dim_2502 : index to i64 + %dim_2503 = tensor.dim %7145, %c0 : tensor + %7151 = arith.index_cast %dim_2503 : index to i64 + %7152 = arith.maxsi %7150, %7151 : i64 + %7153 = arith.index_cast %7152 : i64 to index + %from_elements_2504 = tensor.from_elements %7153, %c4096 : tensor<2xindex> + %7154 = stablehlo.dynamic_broadcast_in_dim %7109, %from_elements_2504, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_2505 = tensor.dim %7154, %c0 : tensor + %7155 = arith.index_cast %dim_2505 : index to i64 + %from_elements_2506 = tensor.from_elements %7155, %c4096_i64 : tensor<2xi64> + %7156 = stablehlo.real_dynamic_slice %7149, %c_22, %from_elements_2506, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_2507 = tensor.from_elements %7155, %c4096_i64, %c1_i64 : tensor<3xi64> + %7157 = stablehlo.dynamic_reshape %7154, %from_elements_2507 : (tensor, tensor<3xi64>) -> tensor + %7158 = stablehlo.dynamic_iota %from_elements_2507, dim = 1 : (tensor<3xi64>) -> tensor + %7159 = stablehlo.concatenate %7157, %7158, dim = 2 : (tensor, tensor) -> tensor + %7160 = "stablehlo.scatter"(%7097, %7159, %7156) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %7161 = stablehlo.slice %6906 [4:5, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %7162 = stablehlo.reshape %7161 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %7163 = stablehlo.custom_call @byteir.non_zero(%7162) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_2508 = tensor.dim %7163, %c0 : tensor + %7164 = arith.index_cast %dim_2508 : index to i64 + %from_elements_2509 = tensor.from_elements %7164, %c1_i64 : tensor<2xi64> + %7165 = stablehlo.real_dynamic_slice %7163, %c_22, %from_elements_2509, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2510 = tensor.dim %7165, %c0 : tensor + %7166 = arith.index_cast %dim_2510 : index to i64 + %from_elements_2511 = tensor.from_elements %7166 : tensor<1xi64> + %7167 = stablehlo.dynamic_reshape %7165, %from_elements_2511 : (tensor, tensor<1xi64>) -> tensor + %from_elements_2512 = tensor.from_elements %7164, %c2_i64 : tensor<2xi64> + %7168 = stablehlo.real_dynamic_slice %7163, %c_24, %from_elements_2512, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2513 = tensor.dim %7168, %c0 : tensor + %7169 = arith.index_cast %dim_2513 : index to i64 + %from_elements_2514 = tensor.from_elements %7169 : tensor<1xi64> + %7170 = stablehlo.dynamic_reshape %7168, %from_elements_2514 
: (tensor, tensor<1xi64>) -> tensor + %dim_2515 = tensor.dim %7170, %c0 : tensor + %7171 = arith.index_cast %dim_2515 : index to i64 + %from_elements_2516 = tensor.from_elements %7171, %c1_i64 : tensor<2xi64> + %7172 = stablehlo.dynamic_reshape %7170, %from_elements_2516 : (tensor, tensor<2xi64>) -> tensor + %dim_2517 = tensor.dim %7172, %c0 : tensor + %7173 = arith.index_cast %dim_2517 : index to i64 + %from_elements_2518 = tensor.from_elements %c1_i64, %7173, %c4096_i64 : tensor<3xi64> + %7174 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_2518, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2519 = tensor.dim %7174, %c1 : tensor<1x?x4096xi64> + %7175 = arith.index_cast %dim_2519 : index to i64 + %from_elements_2520 = tensor.from_elements %c1_i64, %7175, %c4096_i64, %c1_i64 : tensor<4xi64> + %7176 = stablehlo.dynamic_reshape %7174, %from_elements_2520 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %7177 = stablehlo.dynamic_broadcast_in_dim %7172, %from_elements_2518, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2521 = tensor.dim %7177, %c1 : tensor<1x?x4096xi64> + %7178 = arith.index_cast %dim_2521 : index to i64 + %from_elements_2522 = tensor.from_elements %c1_i64, %7178, %c4096_i64, %c1_i64 : tensor<4xi64> + %7179 = stablehlo.dynamic_reshape %7177, %from_elements_2522 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %7180 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_2518, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2523 = tensor.dim %7180, %c1 : tensor<1x?x4096xi64> + %7181 = arith.index_cast %dim_2523 : index to i64 + %from_elements_2524 = tensor.from_elements %c1_i64, %7181, %c4096_i64, %c1_i64 : tensor<4xi64> + %7182 = stablehlo.dynamic_reshape %7180, %from_elements_2524 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %7183 = stablehlo.concatenate %7176, %7179, %7182, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %7184 = "stablehlo.gather"(%6917, %7183) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %7185 = shape.shape_of %7184 : tensor<1x?x4096xf32> -> tensor<3xindex> + %7186 = shape.num_elements %7185 : tensor<3xindex> -> index + %7187 = stablehlo.compute_reshape_shape %7186, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %7188 = stablehlo.dynamic_reshape %7184, %7187 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %7189 = stablehlo.dot %7188, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %7190 = stablehlo.logistic %7189 : tensor + %7191 = shape.shape_of %7190 : tensor -> tensor<2xindex> + %7192 = shape.shape_of %7189 : tensor -> tensor<2xindex> + %7193 = shape.cstr_broadcastable %7191, %7192 : tensor<2xindex>, tensor<2xindex> + %7194 = shape.assuming %7193 -> (tensor) { + %19688 = shape.broadcast %7191, %7192 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %7190, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %7189, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %7195 = shape.shape_of %7194 : tensor -> tensor<2xindex> + %7196 = shape.cstr_broadcastable %7195, %7192 : tensor<2xindex>, 
tensor<2xindex> + %7197 = shape.assuming %7196 -> (tensor) { + %19688 = shape.broadcast %7195, %7192 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %7194, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %7189, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %7198 = stablehlo.dot %7197, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_2525 = tensor.dim %7170, %c0 : tensor + %7199 = arith.index_cast %dim_2525 : index to i64 + %from_elements_2526 = tensor.from_elements %7199, %c1_i64 : tensor<2xi64> + %7200 = stablehlo.dynamic_reshape %7170, %from_elements_2526 : (tensor, tensor<2xi64>) -> tensor + %dim_2527 = tensor.dim %7167, %c0 : tensor + %7201 = arith.index_cast %dim_2527 : index to i64 + %from_elements_2528 = tensor.from_elements %7201, %c1_i64 : tensor<2xi64> + %7202 = stablehlo.dynamic_reshape %7167, %from_elements_2528 : (tensor, tensor<2xi64>) -> tensor + %7203 = stablehlo.concatenate %7200, %7202, dim = 1 : (tensor, tensor) -> tensor + %7204 = "stablehlo.gather"(%6946, %7203) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %7205 = shape.shape_of %7198 : tensor -> tensor<2xindex> + %7206 = shape.shape_of %7204 : tensor -> tensor<2xindex> + %7207 = shape.cstr_broadcastable %7205, %7206 : tensor<2xindex>, tensor<2xindex> + %7208 = shape.assuming %7207 -> (tensor) { + %19688 = shape.broadcast %7205, %7206 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %7198, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %7204, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %7209 = shape.shape_of %7208 : tensor -> tensor<2xindex> + %7210 = stablehlo.dynamic_broadcast_in_dim %7208, %7209, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %7211 = stablehlo.dynamic_broadcast_in_dim %213, %7209, dims = [] : (tensor, tensor<2xindex>) -> tensor + %7212 = stablehlo.multiply %7210, %7211 : tensor + %dim_2529 = tensor.dim %7172, %c0 : tensor + %7213 = arith.index_cast %dim_2529 : index to i64 + %dim_2530 = tensor.dim %7208, %c0 : tensor + %7214 = arith.index_cast %dim_2530 : index to i64 + %7215 = arith.maxsi %7213, %7214 : i64 + %7216 = arith.index_cast %7215 : i64 to index + %from_elements_2531 = tensor.from_elements %7216, %c4096 : tensor<2xindex> + %7217 = stablehlo.dynamic_broadcast_in_dim %7172, %from_elements_2531, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_2532 = tensor.dim %7217, %c0 : tensor + %7218 = arith.index_cast %dim_2532 : index to i64 + %from_elements_2533 = tensor.from_elements %7218, %c4096_i64 : tensor<2xi64> + %7219 = stablehlo.real_dynamic_slice %7212, %c_22, %from_elements_2533, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_2534 = tensor.from_elements %7218, %c4096_i64, %c1_i64 : tensor<3xi64> + %7220 = stablehlo.dynamic_reshape %7217, %from_elements_2534 : (tensor, tensor<3xi64>) -> tensor + %7221 = stablehlo.dynamic_iota %from_elements_2534, dim = 1 : (tensor<3xi64>) -> tensor + %7222 = stablehlo.concatenate %7220, %7221, dim = 2 : (tensor, tensor) -> tensor + %7223 = "stablehlo.scatter"(%7160, 
%7222, %7219) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %7224 = stablehlo.slice %6906 [5:6, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %7225 = stablehlo.reshape %7224 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %7226 = stablehlo.custom_call @byteir.non_zero(%7225) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_2535 = tensor.dim %7226, %c0 : tensor + %7227 = arith.index_cast %dim_2535 : index to i64 + %from_elements_2536 = tensor.from_elements %7227, %c1_i64 : tensor<2xi64> + %7228 = stablehlo.real_dynamic_slice %7226, %c_22, %from_elements_2536, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2537 = tensor.dim %7228, %c0 : tensor + %7229 = arith.index_cast %dim_2537 : index to i64 + %from_elements_2538 = tensor.from_elements %7229 : tensor<1xi64> + %7230 = stablehlo.dynamic_reshape %7228, %from_elements_2538 : (tensor, tensor<1xi64>) -> tensor + %from_elements_2539 = tensor.from_elements %7227, %c2_i64 : tensor<2xi64> + %7231 = stablehlo.real_dynamic_slice %7226, %c_24, %from_elements_2539, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2540 = tensor.dim %7231, %c0 : tensor + %7232 = arith.index_cast %dim_2540 : index to i64 + %from_elements_2541 = tensor.from_elements %7232 : tensor<1xi64> + %7233 = stablehlo.dynamic_reshape %7231, %from_elements_2541 : (tensor, tensor<1xi64>) -> tensor + %dim_2542 = tensor.dim %7233, %c0 : tensor + %7234 = arith.index_cast %dim_2542 : index to i64 + %from_elements_2543 = tensor.from_elements %7234, %c1_i64 : tensor<2xi64> + %7235 = stablehlo.dynamic_reshape %7233, %from_elements_2543 : (tensor, tensor<2xi64>) -> tensor + %dim_2544 = tensor.dim %7235, %c0 : tensor + %7236 = arith.index_cast %dim_2544 : index to i64 + %from_elements_2545 = tensor.from_elements %c1_i64, %7236, %c4096_i64 : tensor<3xi64> + %7237 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_2545, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2546 = tensor.dim %7237, %c1 : tensor<1x?x4096xi64> + %7238 = arith.index_cast %dim_2546 : index to i64 + %from_elements_2547 = tensor.from_elements %c1_i64, %7238, %c4096_i64, %c1_i64 : tensor<4xi64> + %7239 = stablehlo.dynamic_reshape %7237, %from_elements_2547 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %7240 = stablehlo.dynamic_broadcast_in_dim %7235, %from_elements_2545, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2548 = tensor.dim %7240, %c1 : tensor<1x?x4096xi64> + %7241 = arith.index_cast %dim_2548 : index to i64 + %from_elements_2549 = tensor.from_elements %c1_i64, %7241, %c4096_i64, %c1_i64 : tensor<4xi64> + %7242 = stablehlo.dynamic_reshape %7240, %from_elements_2549 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %7243 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_2545, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2550 = tensor.dim %7243, %c1 : tensor<1x?x4096xi64> + %7244 = arith.index_cast %dim_2550 : index to i64 + %from_elements_2551 = tensor.from_elements %c1_i64, %7244, %c4096_i64, %c1_i64 : tensor<4xi64> + %7245 = stablehlo.dynamic_reshape %7243, %from_elements_2551 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + 
%7246 = stablehlo.concatenate %7239, %7242, %7245, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %7247 = "stablehlo.gather"(%6917, %7246) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %7248 = shape.shape_of %7247 : tensor<1x?x4096xf32> -> tensor<3xindex> + %7249 = shape.num_elements %7248 : tensor<3xindex> -> index + %7250 = stablehlo.compute_reshape_shape %7249, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %7251 = stablehlo.dynamic_reshape %7247, %7250 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %7252 = stablehlo.dot %7251, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %7253 = stablehlo.logistic %7252 : tensor + %7254 = shape.shape_of %7253 : tensor -> tensor<2xindex> + %7255 = shape.shape_of %7252 : tensor -> tensor<2xindex> + %7256 = shape.cstr_broadcastable %7254, %7255 : tensor<2xindex>, tensor<2xindex> + %7257 = shape.assuming %7256 -> (tensor) { + %19688 = shape.broadcast %7254, %7255 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %7253, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %7252, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %7258 = shape.shape_of %7257 : tensor -> tensor<2xindex> + %7259 = shape.cstr_broadcastable %7258, %7255 : tensor<2xindex>, tensor<2xindex> + %7260 = shape.assuming %7259 -> (tensor) { + %19688 = shape.broadcast %7258, %7255 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %7257, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %7252, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %7261 = stablehlo.dot %7260, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_2552 = tensor.dim %7233, %c0 : tensor + %7262 = arith.index_cast %dim_2552 : index to i64 + %from_elements_2553 = tensor.from_elements %7262, %c1_i64 : tensor<2xi64> + %7263 = stablehlo.dynamic_reshape %7233, %from_elements_2553 : (tensor, tensor<2xi64>) -> tensor + %dim_2554 = tensor.dim %7230, %c0 : tensor + %7264 = arith.index_cast %dim_2554 : index to i64 + %from_elements_2555 = tensor.from_elements %7264, %c1_i64 : tensor<2xi64> + %7265 = stablehlo.dynamic_reshape %7230, %from_elements_2555 : (tensor, tensor<2xi64>) -> tensor + %7266 = stablehlo.concatenate %7263, %7265, dim = 1 : (tensor, tensor) -> tensor + %7267 = "stablehlo.gather"(%6946, %7266) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %7268 = shape.shape_of %7261 : tensor -> tensor<2xindex> + %7269 = shape.shape_of %7267 : tensor -> tensor<2xindex> + %7270 = shape.cstr_broadcastable %7268, %7269 : tensor<2xindex>, tensor<2xindex> + %7271 = shape.assuming %7270 -> (tensor) { + %19688 = shape.broadcast %7268, %7269 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %7261, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %7267, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = 
stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %7272 = shape.shape_of %7271 : tensor -> tensor<2xindex> + %7273 = stablehlo.dynamic_broadcast_in_dim %7271, %7272, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %7274 = stablehlo.dynamic_broadcast_in_dim %213, %7272, dims = [] : (tensor, tensor<2xindex>) -> tensor + %7275 = stablehlo.multiply %7273, %7274 : tensor + %dim_2556 = tensor.dim %7235, %c0 : tensor + %7276 = arith.index_cast %dim_2556 : index to i64 + %dim_2557 = tensor.dim %7271, %c0 : tensor + %7277 = arith.index_cast %dim_2557 : index to i64 + %7278 = arith.maxsi %7276, %7277 : i64 + %7279 = arith.index_cast %7278 : i64 to index + %from_elements_2558 = tensor.from_elements %7279, %c4096 : tensor<2xindex> + %7280 = stablehlo.dynamic_broadcast_in_dim %7235, %from_elements_2558, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_2559 = tensor.dim %7280, %c0 : tensor + %7281 = arith.index_cast %dim_2559 : index to i64 + %from_elements_2560 = tensor.from_elements %7281, %c4096_i64 : tensor<2xi64> + %7282 = stablehlo.real_dynamic_slice %7275, %c_22, %from_elements_2560, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_2561 = tensor.from_elements %7281, %c4096_i64, %c1_i64 : tensor<3xi64> + %7283 = stablehlo.dynamic_reshape %7280, %from_elements_2561 : (tensor, tensor<3xi64>) -> tensor + %7284 = stablehlo.dynamic_iota %from_elements_2561, dim = 1 : (tensor<3xi64>) -> tensor + %7285 = stablehlo.concatenate %7283, %7284, dim = 2 : (tensor, tensor) -> tensor + %7286 = "stablehlo.scatter"(%7223, %7285, %7282) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %7287 = stablehlo.slice %6906 [6:7, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %7288 = stablehlo.reshape %7287 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %7289 = stablehlo.custom_call @byteir.non_zero(%7288) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_2562 = tensor.dim %7289, %c0 : tensor + %7290 = arith.index_cast %dim_2562 : index to i64 + %from_elements_2563 = tensor.from_elements %7290, %c1_i64 : tensor<2xi64> + %7291 = stablehlo.real_dynamic_slice %7289, %c_22, %from_elements_2563, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2564 = tensor.dim %7291, %c0 : tensor + %7292 = arith.index_cast %dim_2564 : index to i64 + %from_elements_2565 = tensor.from_elements %7292 : tensor<1xi64> + %7293 = stablehlo.dynamic_reshape %7291, %from_elements_2565 : (tensor, tensor<1xi64>) -> tensor + %from_elements_2566 = tensor.from_elements %7290, %c2_i64 : tensor<2xi64> + %7294 = stablehlo.real_dynamic_slice %7289, %c_24, %from_elements_2566, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2567 = tensor.dim %7294, %c0 : tensor + %7295 = arith.index_cast %dim_2567 : index to i64 + %from_elements_2568 = tensor.from_elements %7295 : tensor<1xi64> + %7296 = stablehlo.dynamic_reshape %7294, %from_elements_2568 : (tensor, tensor<1xi64>) -> tensor + %dim_2569 = tensor.dim %7296, %c0 : tensor + %7297 = arith.index_cast %dim_2569 : index to i64 + %from_elements_2570 = tensor.from_elements %7297, %c1_i64 : tensor<2xi64> + %7298 = stablehlo.dynamic_reshape %7296, %from_elements_2570 : (tensor, tensor<2xi64>) -> tensor + %dim_2571 
= tensor.dim %7298, %c0 : tensor + %7299 = arith.index_cast %dim_2571 : index to i64 + %from_elements_2572 = tensor.from_elements %c1_i64, %7299, %c4096_i64 : tensor<3xi64> + %7300 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_2572, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2573 = tensor.dim %7300, %c1 : tensor<1x?x4096xi64> + %7301 = arith.index_cast %dim_2573 : index to i64 + %from_elements_2574 = tensor.from_elements %c1_i64, %7301, %c4096_i64, %c1_i64 : tensor<4xi64> + %7302 = stablehlo.dynamic_reshape %7300, %from_elements_2574 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %7303 = stablehlo.dynamic_broadcast_in_dim %7298, %from_elements_2572, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2575 = tensor.dim %7303, %c1 : tensor<1x?x4096xi64> + %7304 = arith.index_cast %dim_2575 : index to i64 + %from_elements_2576 = tensor.from_elements %c1_i64, %7304, %c4096_i64, %c1_i64 : tensor<4xi64> + %7305 = stablehlo.dynamic_reshape %7303, %from_elements_2576 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %7306 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_2572, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2577 = tensor.dim %7306, %c1 : tensor<1x?x4096xi64> + %7307 = arith.index_cast %dim_2577 : index to i64 + %from_elements_2578 = tensor.from_elements %c1_i64, %7307, %c4096_i64, %c1_i64 : tensor<4xi64> + %7308 = stablehlo.dynamic_reshape %7306, %from_elements_2578 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %7309 = stablehlo.concatenate %7302, %7305, %7308, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %7310 = "stablehlo.gather"(%6917, %7309) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %7311 = shape.shape_of %7310 : tensor<1x?x4096xf32> -> tensor<3xindex> + %7312 = shape.num_elements %7311 : tensor<3xindex> -> index + %7313 = stablehlo.compute_reshape_shape %7312, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %7314 = stablehlo.dynamic_reshape %7310, %7313 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %7315 = stablehlo.dot %7314, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %7316 = stablehlo.logistic %7315 : tensor + %7317 = shape.shape_of %7316 : tensor -> tensor<2xindex> + %7318 = shape.shape_of %7315 : tensor -> tensor<2xindex> + %7319 = shape.cstr_broadcastable %7317, %7318 : tensor<2xindex>, tensor<2xindex> + %7320 = shape.assuming %7319 -> (tensor) { + %19688 = shape.broadcast %7317, %7318 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %7316, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %7315, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %7321 = shape.shape_of %7320 : tensor -> tensor<2xindex> + %7322 = shape.cstr_broadcastable %7321, %7318 : tensor<2xindex>, tensor<2xindex> + %7323 = shape.assuming %7322 -> (tensor) { + %19688 = shape.broadcast %7321, %7318 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %7320, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %7315, 
%19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %7324 = stablehlo.dot %7323, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_2579 = tensor.dim %7296, %c0 : tensor + %7325 = arith.index_cast %dim_2579 : index to i64 + %from_elements_2580 = tensor.from_elements %7325, %c1_i64 : tensor<2xi64> + %7326 = stablehlo.dynamic_reshape %7296, %from_elements_2580 : (tensor, tensor<2xi64>) -> tensor + %dim_2581 = tensor.dim %7293, %c0 : tensor + %7327 = arith.index_cast %dim_2581 : index to i64 + %from_elements_2582 = tensor.from_elements %7327, %c1_i64 : tensor<2xi64> + %7328 = stablehlo.dynamic_reshape %7293, %from_elements_2582 : (tensor, tensor<2xi64>) -> tensor + %7329 = stablehlo.concatenate %7326, %7328, dim = 1 : (tensor, tensor) -> tensor + %7330 = "stablehlo.gather"(%6946, %7329) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %7331 = shape.shape_of %7324 : tensor -> tensor<2xindex> + %7332 = shape.shape_of %7330 : tensor -> tensor<2xindex> + %7333 = shape.cstr_broadcastable %7331, %7332 : tensor<2xindex>, tensor<2xindex> + %7334 = shape.assuming %7333 -> (tensor) { + %19688 = shape.broadcast %7331, %7332 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %7324, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %7330, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %7335 = shape.shape_of %7334 : tensor -> tensor<2xindex> + %7336 = stablehlo.dynamic_broadcast_in_dim %7334, %7335, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %7337 = stablehlo.dynamic_broadcast_in_dim %213, %7335, dims = [] : (tensor, tensor<2xindex>) -> tensor + %7338 = stablehlo.multiply %7336, %7337 : tensor + %dim_2583 = tensor.dim %7298, %c0 : tensor + %7339 = arith.index_cast %dim_2583 : index to i64 + %dim_2584 = tensor.dim %7334, %c0 : tensor + %7340 = arith.index_cast %dim_2584 : index to i64 + %7341 = arith.maxsi %7339, %7340 : i64 + %7342 = arith.index_cast %7341 : i64 to index + %from_elements_2585 = tensor.from_elements %7342, %c4096 : tensor<2xindex> + %7343 = stablehlo.dynamic_broadcast_in_dim %7298, %from_elements_2585, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_2586 = tensor.dim %7343, %c0 : tensor + %7344 = arith.index_cast %dim_2586 : index to i64 + %from_elements_2587 = tensor.from_elements %7344, %c4096_i64 : tensor<2xi64> + %7345 = stablehlo.real_dynamic_slice %7338, %c_22, %from_elements_2587, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_2588 = tensor.from_elements %7344, %c4096_i64, %c1_i64 : tensor<3xi64> + %7346 = stablehlo.dynamic_reshape %7343, %from_elements_2588 : (tensor, tensor<3xi64>) -> tensor + %7347 = stablehlo.dynamic_iota %from_elements_2588, dim = 1 : (tensor<3xi64>) -> tensor + %7348 = stablehlo.concatenate %7346, %7347, dim = 2 : (tensor, tensor) -> tensor + %7349 = "stablehlo.scatter"(%7286, %7348, %7345) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %7350 
= stablehlo.slice %6906 [7:8, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %7351 = stablehlo.reshape %7350 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %7352 = stablehlo.custom_call @byteir.non_zero(%7351) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_2589 = tensor.dim %7352, %c0 : tensor + %7353 = arith.index_cast %dim_2589 : index to i64 + %from_elements_2590 = tensor.from_elements %7353, %c1_i64 : tensor<2xi64> + %7354 = stablehlo.real_dynamic_slice %7352, %c_22, %from_elements_2590, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2591 = tensor.dim %7354, %c0 : tensor + %7355 = arith.index_cast %dim_2591 : index to i64 + %from_elements_2592 = tensor.from_elements %7355 : tensor<1xi64> + %7356 = stablehlo.dynamic_reshape %7354, %from_elements_2592 : (tensor, tensor<1xi64>) -> tensor + %from_elements_2593 = tensor.from_elements %7353, %c2_i64 : tensor<2xi64> + %7357 = stablehlo.real_dynamic_slice %7352, %c_24, %from_elements_2593, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2594 = tensor.dim %7357, %c0 : tensor + %7358 = arith.index_cast %dim_2594 : index to i64 + %from_elements_2595 = tensor.from_elements %7358 : tensor<1xi64> + %7359 = stablehlo.dynamic_reshape %7357, %from_elements_2595 : (tensor, tensor<1xi64>) -> tensor + %dim_2596 = tensor.dim %7359, %c0 : tensor + %7360 = arith.index_cast %dim_2596 : index to i64 + %from_elements_2597 = tensor.from_elements %7360, %c1_i64 : tensor<2xi64> + %7361 = stablehlo.dynamic_reshape %7359, %from_elements_2597 : (tensor, tensor<2xi64>) -> tensor + %dim_2598 = tensor.dim %7361, %c0 : tensor + %7362 = arith.index_cast %dim_2598 : index to i64 + %from_elements_2599 = tensor.from_elements %c1_i64, %7362, %c4096_i64 : tensor<3xi64> + %7363 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_2599, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2600 = tensor.dim %7363, %c1 : tensor<1x?x4096xi64> + %7364 = arith.index_cast %dim_2600 : index to i64 + %from_elements_2601 = tensor.from_elements %c1_i64, %7364, %c4096_i64, %c1_i64 : tensor<4xi64> + %7365 = stablehlo.dynamic_reshape %7363, %from_elements_2601 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %7366 = stablehlo.dynamic_broadcast_in_dim %7361, %from_elements_2599, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2602 = tensor.dim %7366, %c1 : tensor<1x?x4096xi64> + %7367 = arith.index_cast %dim_2602 : index to i64 + %from_elements_2603 = tensor.from_elements %c1_i64, %7367, %c4096_i64, %c1_i64 : tensor<4xi64> + %7368 = stablehlo.dynamic_reshape %7366, %from_elements_2603 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %7369 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_2599, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2604 = tensor.dim %7369, %c1 : tensor<1x?x4096xi64> + %7370 = arith.index_cast %dim_2604 : index to i64 + %from_elements_2605 = tensor.from_elements %c1_i64, %7370, %c4096_i64, %c1_i64 : tensor<4xi64> + %7371 = stablehlo.dynamic_reshape %7369, %from_elements_2605 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %7372 = stablehlo.concatenate %7365, %7368, %7371, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %7373 = "stablehlo.gather"(%6917, %7372) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, 
tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %7374 = shape.shape_of %7373 : tensor<1x?x4096xf32> -> tensor<3xindex> + %7375 = shape.num_elements %7374 : tensor<3xindex> -> index + %7376 = stablehlo.compute_reshape_shape %7375, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %7377 = stablehlo.dynamic_reshape %7373, %7376 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %7378 = stablehlo.dot %7377, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %7379 = stablehlo.logistic %7378 : tensor + %7380 = shape.shape_of %7379 : tensor -> tensor<2xindex> + %7381 = shape.shape_of %7378 : tensor -> tensor<2xindex> + %7382 = shape.cstr_broadcastable %7380, %7381 : tensor<2xindex>, tensor<2xindex> + %7383 = shape.assuming %7382 -> (tensor) { + %19688 = shape.broadcast %7380, %7381 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %7379, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %7378, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %7384 = shape.shape_of %7383 : tensor -> tensor<2xindex> + %7385 = shape.cstr_broadcastable %7384, %7381 : tensor<2xindex>, tensor<2xindex> + %7386 = shape.assuming %7385 -> (tensor) { + %19688 = shape.broadcast %7384, %7381 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %7383, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %7378, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %7387 = stablehlo.dot %7386, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_2606 = tensor.dim %7359, %c0 : tensor + %7388 = arith.index_cast %dim_2606 : index to i64 + %from_elements_2607 = tensor.from_elements %7388, %c1_i64 : tensor<2xi64> + %7389 = stablehlo.dynamic_reshape %7359, %from_elements_2607 : (tensor, tensor<2xi64>) -> tensor + %dim_2608 = tensor.dim %7356, %c0 : tensor + %7390 = arith.index_cast %dim_2608 : index to i64 + %from_elements_2609 = tensor.from_elements %7390, %c1_i64 : tensor<2xi64> + %7391 = stablehlo.dynamic_reshape %7356, %from_elements_2609 : (tensor, tensor<2xi64>) -> tensor + %7392 = stablehlo.concatenate %7389, %7391, dim = 1 : (tensor, tensor) -> tensor + %7393 = "stablehlo.gather"(%6946, %7392) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %7394 = shape.shape_of %7387 : tensor -> tensor<2xindex> + %7395 = shape.shape_of %7393 : tensor -> tensor<2xindex> + %7396 = shape.cstr_broadcastable %7394, %7395 : tensor<2xindex>, tensor<2xindex> + %7397 = shape.assuming %7396 -> (tensor) { + %19688 = shape.broadcast %7394, %7395 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %7387, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %7393, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %7398 = shape.shape_of %7397 : tensor -> tensor<2xindex> + %7399 = stablehlo.dynamic_broadcast_in_dim %7397, %7398, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %7400 = stablehlo.dynamic_broadcast_in_dim %213, %7398, dims = [] 
: (tensor, tensor<2xindex>) -> tensor + %7401 = stablehlo.multiply %7399, %7400 : tensor + %dim_2610 = tensor.dim %7361, %c0 : tensor + %7402 = arith.index_cast %dim_2610 : index to i64 + %dim_2611 = tensor.dim %7397, %c0 : tensor + %7403 = arith.index_cast %dim_2611 : index to i64 + %7404 = arith.maxsi %7402, %7403 : i64 + %7405 = arith.index_cast %7404 : i64 to index + %from_elements_2612 = tensor.from_elements %7405, %c4096 : tensor<2xindex> + %7406 = stablehlo.dynamic_broadcast_in_dim %7361, %from_elements_2612, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_2613 = tensor.dim %7406, %c0 : tensor + %7407 = arith.index_cast %dim_2613 : index to i64 + %from_elements_2614 = tensor.from_elements %7407, %c4096_i64 : tensor<2xi64> + %7408 = stablehlo.real_dynamic_slice %7401, %c_22, %from_elements_2614, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_2615 = tensor.from_elements %7407, %c4096_i64, %c1_i64 : tensor<3xi64> + %7409 = stablehlo.dynamic_reshape %7406, %from_elements_2615 : (tensor, tensor<3xi64>) -> tensor + %7410 = stablehlo.dynamic_iota %from_elements_2615, dim = 1 : (tensor<3xi64>) -> tensor + %7411 = stablehlo.concatenate %7409, %7410, dim = 2 : (tensor, tensor) -> tensor + %7412 = "stablehlo.scatter"(%7349, %7411, %7408) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %7413 = stablehlo.reshape %7412 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %7414 = stablehlo.add %6879, %7413 : tensor<3x1x4096xf32> + %7415 = stablehlo.broadcast_in_dim %7414, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %7416 = stablehlo.power %7415, %15 : tensor<3x1x4096xf32> + %7417 = stablehlo.reduce(%7416 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %7418 = stablehlo.reshape %7417 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %7419 = stablehlo.broadcast_in_dim %7418, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %7420 = stablehlo.divide %7419, %21 : tensor<3x1x1xf32> + %7421 = stablehlo.broadcast_in_dim %7420, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %7422 = stablehlo.add %7421, %25 : tensor<3x1x1xf32> + %7423 = stablehlo.rsqrt %7422 : tensor<3x1x1xf32> + %7424 = stablehlo.broadcast_in_dim %7423, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %7425 = stablehlo.multiply %7415, %7424 : tensor<3x1x4096xf32> + %7426 = stablehlo.broadcast_in_dim %7425, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %7427 = stablehlo.multiply %7426, %31 : tensor<3x1x4096xf32> + %7428 = stablehlo.reshape %7427 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %7429 = stablehlo.dot %7428, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %7430 = stablehlo.reshape %7429 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %7431 = stablehlo.dot %7428, %37 : (tensor<3x4096xf32>, tensor<4096x1024xf32>) -> tensor<3x1024xf32> + %7432 = stablehlo.reshape %7431 : (tensor<3x1024xf32>) -> tensor<3x1x1024xf32> + %7433 = stablehlo.reshape %7430 : (tensor<3x1x4096xf32>) -> tensor<3x1x32x128xf32> + %7434 = stablehlo.transpose %7433, dims = [0, 2, 1, 3] : (tensor<3x1x32x128xf32>) -> tensor<3x32x1x128xf32> + %7435 = stablehlo.reshape %7432 : 
(tensor<3x1x1024xf32>) -> tensor<3x1x8x128xf32> + %7436 = stablehlo.transpose %7435, dims = [0, 2, 1, 3] : (tensor<3x1x8x128xf32>) -> tensor<3x8x1x128xf32> + %7437 = stablehlo.slice %arg24 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %7438 = stablehlo.slice %arg25 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %7439 = "stablehlo.gather"(%7437, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %7440 = stablehlo.reshape %7439 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %7441 = "stablehlo.gather"(%7438, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %7442 = stablehlo.reshape %7441 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %7443 = stablehlo.broadcast_in_dim %7434, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %7444 = stablehlo.broadcast_in_dim %7440, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %7445 = stablehlo.multiply %7443, %7444 : tensor<3x32x1x128xf32> + %7446 = stablehlo.slice %7434 [0:3, 0:32, 0:1, 0:64] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %7447 = stablehlo.slice %7434 [0:3, 0:32, 0:1, 64:128] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %7448 = stablehlo.negate %7447 : tensor<3x32x1x64xf32> + %7449 = stablehlo.concatenate %7448, %7446, dim = 3 : (tensor<3x32x1x64xf32>, tensor<3x32x1x64xf32>) -> tensor<3x32x1x128xf32> + %7450 = stablehlo.broadcast_in_dim %7449, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %7451 = stablehlo.broadcast_in_dim %7442, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %7452 = stablehlo.multiply %7450, %7451 : tensor<3x32x1x128xf32> + %7453 = stablehlo.add %7445, %7452 : tensor<3x32x1x128xf32> + %7454 = stablehlo.broadcast_in_dim %7436, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %7455 = stablehlo.broadcast_in_dim %7440, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %7456 = stablehlo.multiply %7454, %7455 : tensor<3x8x1x128xf32> + %7457 = stablehlo.slice %7436 [0:3, 0:8, 0:1, 0:64] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %7458 = stablehlo.slice %7436 [0:3, 0:8, 0:1, 64:128] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %7459 = stablehlo.negate %7458 : tensor<3x8x1x64xf32> + %7460 = stablehlo.concatenate %7459, %7457, dim = 3 : (tensor<3x8x1x64xf32>, tensor<3x8x1x64xf32>) -> tensor<3x8x1x128xf32> + %7461 = stablehlo.broadcast_in_dim %7460, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %7462 = stablehlo.broadcast_in_dim %7442, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %7463 = stablehlo.multiply %7461, %7462 : tensor<3x8x1x128xf32> + %7464 = stablehlo.add %7456, %7463 : tensor<3x8x1x128xf32> + %7465 = stablehlo.concatenate %arg89, %7464, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %7466 = stablehlo.concatenate %arg90, %7436, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %7467 = stablehlo.reshape %7465 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %7468 = stablehlo.broadcast_in_dim %7467, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %7469 = stablehlo.reshape %7468 : (tensor<3x8x4x8x128xf32>) -> 
tensor<3x32x8x128xf32> + %7470 = stablehlo.reshape %7466 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %7471 = stablehlo.broadcast_in_dim %7470, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %7472 = stablehlo.reshape %7471 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %7473 = stablehlo.transpose %7469, dims = [0, 1, 3, 2] : (tensor<3x32x8x128xf32>) -> tensor<3x32x128x8xf32> + %7474 = stablehlo.reshape %7453 : (tensor<3x32x1x128xf32>) -> tensor<96x1x128xf32> + %7475 = stablehlo.reshape %7473 : (tensor<3x32x128x8xf32>) -> tensor<96x128x8xf32> + %7476 = stablehlo.broadcast_in_dim %7475, dims = [0, 1, 2] : (tensor<96x128x8xf32>) -> tensor<96x128x8xf32> + %7477 = stablehlo.dot_general %7474, %7476, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x128xf32>, tensor<96x128x8xf32>) -> tensor<96x1x8xf32> + %7478 = stablehlo.reshape %7477 : (tensor<96x1x8xf32>) -> tensor<3x32x1x8xf32> + %7479 = stablehlo.broadcast_in_dim %7478, dims = [0, 1, 2, 3] : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %7480 = stablehlo.divide %7479, %89 : tensor<3x32x1x8xf32> + %7481 = stablehlo.custom_call @byteir.softmax(%7480) {byteir_attrs = {axis = 3 : i64}} : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %7482 = stablehlo.reshape %7481 : (tensor<3x32x1x8xf32>) -> tensor<96x1x8xf32> + %7483 = stablehlo.reshape %7472 : (tensor<3x32x8x128xf32>) -> tensor<96x8x128xf32> + %7484 = stablehlo.broadcast_in_dim %7483, dims = [0, 1, 2] : (tensor<96x8x128xf32>) -> tensor<96x8x128xf32> + %7485 = stablehlo.dot_general %7482, %7484, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x8xf32>, tensor<96x8x128xf32>) -> tensor<96x1x128xf32> + %7486 = stablehlo.reshape %7485 : (tensor<96x1x128xf32>) -> tensor<3x32x1x128xf32> + %7487 = stablehlo.transpose %7486, dims = [0, 2, 1, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x1x32x128xf32> + %7488 = stablehlo.reshape %7487 : (tensor<3x1x32x128xf32>) -> tensor<3x1x4096xf32> + %7489 = stablehlo.reshape %7488 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %7490 = stablehlo.dot %7489, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %7491 = stablehlo.reshape %7490 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %7492 = stablehlo.add %7414, %7491 : tensor<3x1x4096xf32> + %7493 = stablehlo.broadcast_in_dim %7492, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %7494 = stablehlo.power %7493, %15 : tensor<3x1x4096xf32> + %7495 = stablehlo.reduce(%7494 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %7496 = stablehlo.reshape %7495 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %7497 = stablehlo.broadcast_in_dim %7496, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %7498 = stablehlo.divide %7497, %21 : tensor<3x1x1xf32> + %7499 = stablehlo.broadcast_in_dim %7498, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %7500 = stablehlo.add %7499, %25 : tensor<3x1x1xf32> + %7501 = stablehlo.rsqrt %7500 : tensor<3x1x1xf32> + %7502 = stablehlo.broadcast_in_dim %7501, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %7503 = stablehlo.multiply %7493, %7502 : tensor<3x1x4096xf32> + %7504 = stablehlo.broadcast_in_dim %7503, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %7505 = stablehlo.multiply %7504, %31 : tensor<3x1x4096xf32> + %7506 = stablehlo.reshape %7505 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %7507 = stablehlo.dot %7506, %117 : 
(tensor<3x4096xf32>, tensor<4096x8xf32>) -> tensor<3x8xf32> + %7508 = stablehlo.custom_call @byteir.softmax(%7507) {byteir_attrs = {axis = 1 : i64}} : (tensor<3x8xf32>) -> tensor<3x8xf32> + %7509:2 = stablehlo.custom_call @byteir.top_k(%7508) {byteir_attrs = {axis = [1], k = 2 : i64, sorted = true}} : (tensor<3x8xf32>) -> (tensor<3x2xf32>, tensor<3x2xi64>) + %7510 = stablehlo.reduce(%7509#0 init: %cst_20) applies stablehlo.add across dimensions = [1] : (tensor<3x2xf32>, tensor) -> tensor<3xf32> + %7511 = stablehlo.reshape %7510 : (tensor<3xf32>) -> tensor<3x1xf32> + %7512 = stablehlo.broadcast_in_dim %7509#0, dims = [0, 1] : (tensor<3x2xf32>) -> tensor<3x2xf32> + %7513 = stablehlo.broadcast_in_dim %7511, dims = [0, 1] : (tensor<3x1xf32>) -> tensor<3x2xf32> + %7514 = stablehlo.divide %7512, %7513 : tensor<3x2xf32> + %7515 = stablehlo.reshape %7509#1 : (tensor<3x2xi64>) -> tensor<3x2x1xi64> + %7516 = stablehlo.broadcast_in_dim %7515, dims = [0, 1, 2] : (tensor<3x2x1xi64>) -> tensor<3x2x8xi64> + %7517 = stablehlo.compare EQ, %7516, %137, SIGNED : (tensor<3x2x8xi64>, tensor<3x2x8xi64>) -> tensor<3x2x8xi1> + %7518 = stablehlo.convert %7517 : (tensor<3x2x8xi1>) -> tensor<3x2x8xi64> + %7519 = stablehlo.transpose %7518, dims = [2, 1, 0] : (tensor<3x2x8xi64>) -> tensor<8x2x3xi64> + %7520 = stablehlo.slice %7519 [0:1, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %7521 = stablehlo.reshape %7520 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %7522 = stablehlo.custom_call @byteir.non_zero(%7521) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_2616 = tensor.dim %7522, %c0 : tensor + %7523 = arith.index_cast %dim_2616 : index to i64 + %from_elements_2617 = tensor.from_elements %7523, %c1_i64 : tensor<2xi64> + %7524 = stablehlo.real_dynamic_slice %7522, %c_22, %from_elements_2617, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2618 = tensor.dim %7524, %c0 : tensor + %7525 = arith.index_cast %dim_2618 : index to i64 + %from_elements_2619 = tensor.from_elements %7525 : tensor<1xi64> + %7526 = stablehlo.dynamic_reshape %7524, %from_elements_2619 : (tensor, tensor<1xi64>) -> tensor + %from_elements_2620 = tensor.from_elements %7523, %c2_i64 : tensor<2xi64> + %7527 = stablehlo.real_dynamic_slice %7522, %c_24, %from_elements_2620, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2621 = tensor.dim %7527, %c0 : tensor + %7528 = arith.index_cast %dim_2621 : index to i64 + %from_elements_2622 = tensor.from_elements %7528 : tensor<1xi64> + %7529 = stablehlo.dynamic_reshape %7527, %from_elements_2622 : (tensor, tensor<1xi64>) -> tensor + %7530 = stablehlo.reshape %7506 : (tensor<3x4096xf32>) -> tensor<1x3x4096xf32> + %dim_2623 = tensor.dim %7529, %c0 : tensor + %7531 = arith.index_cast %dim_2623 : index to i64 + %from_elements_2624 = tensor.from_elements %7531, %c1_i64 : tensor<2xi64> + %7532 = stablehlo.dynamic_reshape %7529, %from_elements_2624 : (tensor, tensor<2xi64>) -> tensor + %dim_2625 = tensor.dim %7532, %c0 : tensor + %7533 = arith.index_cast %dim_2625 : index to i64 + %from_elements_2626 = tensor.from_elements %c1_i64, %7533, %c4096_i64 : tensor<3xi64> + %7534 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_2626, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2627 = tensor.dim %7534, %c1 : tensor<1x?x4096xi64> + %7535 = arith.index_cast %dim_2627 : index to i64 + %from_elements_2628 = tensor.from_elements %c1_i64, %7535, %c4096_i64, %c1_i64 : tensor<4xi64> + %7536 = 
stablehlo.dynamic_reshape %7534, %from_elements_2628 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %7537 = stablehlo.dynamic_broadcast_in_dim %7532, %from_elements_2626, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2629 = tensor.dim %7537, %c1 : tensor<1x?x4096xi64> + %7538 = arith.index_cast %dim_2629 : index to i64 + %from_elements_2630 = tensor.from_elements %c1_i64, %7538, %c4096_i64, %c1_i64 : tensor<4xi64> + %7539 = stablehlo.dynamic_reshape %7537, %from_elements_2630 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %7540 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_2626, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2631 = tensor.dim %7540, %c1 : tensor<1x?x4096xi64> + %7541 = arith.index_cast %dim_2631 : index to i64 + %from_elements_2632 = tensor.from_elements %c1_i64, %7541, %c4096_i64, %c1_i64 : tensor<4xi64> + %7542 = stablehlo.dynamic_reshape %7540, %from_elements_2632 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %7543 = stablehlo.concatenate %7536, %7539, %7542, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %7544 = "stablehlo.gather"(%7530, %7543) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %7545 = shape.shape_of %7544 : tensor<1x?x4096xf32> -> tensor<3xindex> + %7546 = shape.num_elements %7545 : tensor<3xindex> -> index + %7547 = stablehlo.compute_reshape_shape %7546, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %7548 = stablehlo.dynamic_reshape %7544, %7547 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %7549 = stablehlo.dot %7548, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %7550 = stablehlo.logistic %7549 : tensor + %7551 = shape.shape_of %7550 : tensor -> tensor<2xindex> + %7552 = shape.shape_of %7549 : tensor -> tensor<2xindex> + %7553 = shape.cstr_broadcastable %7551, %7552 : tensor<2xindex>, tensor<2xindex> + %7554 = shape.assuming %7553 -> (tensor) { + %19688 = shape.broadcast %7551, %7552 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %7550, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %7549, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %7555 = shape.shape_of %7554 : tensor -> tensor<2xindex> + %7556 = shape.cstr_broadcastable %7555, %7552 : tensor<2xindex>, tensor<2xindex> + %7557 = shape.assuming %7556 -> (tensor) { + %19688 = shape.broadcast %7555, %7552 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %7554, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %7549, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %7558 = stablehlo.dot %7557, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %7559 = stablehlo.reshape %7514 : (tensor<3x2xf32>) -> tensor<3x2x1xf32> + %dim_2633 = tensor.dim %7529, %c0 : tensor + %7560 = arith.index_cast %dim_2633 : index to i64 + %from_elements_2634 = tensor.from_elements %7560, %c1_i64 : tensor<2xi64> + %7561 = stablehlo.dynamic_reshape %7529, 
%from_elements_2634 : (tensor, tensor<2xi64>) -> tensor + %dim_2635 = tensor.dim %7526, %c0 : tensor + %7562 = arith.index_cast %dim_2635 : index to i64 + %from_elements_2636 = tensor.from_elements %7562, %c1_i64 : tensor<2xi64> + %7563 = stablehlo.dynamic_reshape %7526, %from_elements_2636 : (tensor, tensor<2xi64>) -> tensor + %7564 = stablehlo.concatenate %7561, %7563, dim = 1 : (tensor, tensor) -> tensor + %7565 = "stablehlo.gather"(%7559, %7564) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %7566 = shape.shape_of %7558 : tensor -> tensor<2xindex> + %7567 = shape.shape_of %7565 : tensor -> tensor<2xindex> + %7568 = shape.cstr_broadcastable %7566, %7567 : tensor<2xindex>, tensor<2xindex> + %7569 = shape.assuming %7568 -> (tensor) { + %19688 = shape.broadcast %7566, %7567 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %7558, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %7565, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %7570 = shape.shape_of %7569 : tensor -> tensor<2xindex> + %7571 = stablehlo.dynamic_broadcast_in_dim %7569, %7570, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %7572 = stablehlo.dynamic_broadcast_in_dim %213, %7570, dims = [] : (tensor, tensor<2xindex>) -> tensor + %7573 = stablehlo.multiply %7571, %7572 : tensor + %dim_2637 = tensor.dim %7532, %c0 : tensor + %7574 = arith.index_cast %dim_2637 : index to i64 + %dim_2638 = tensor.dim %7569, %c0 : tensor + %7575 = arith.index_cast %dim_2638 : index to i64 + %7576 = arith.maxsi %7574, %7575 : i64 + %7577 = arith.index_cast %7576 : i64 to index + %from_elements_2639 = tensor.from_elements %7577, %c4096 : tensor<2xindex> + %7578 = stablehlo.dynamic_broadcast_in_dim %7532, %from_elements_2639, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_2640 = tensor.dim %7578, %c0 : tensor + %7579 = arith.index_cast %dim_2640 : index to i64 + %from_elements_2641 = tensor.from_elements %7579, %c4096_i64 : tensor<2xi64> + %7580 = stablehlo.real_dynamic_slice %7573, %c_22, %from_elements_2641, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_2642 = tensor.from_elements %7579, %c4096_i64, %c1_i64 : tensor<3xi64> + %7581 = stablehlo.dynamic_reshape %7578, %from_elements_2642 : (tensor, tensor<3xi64>) -> tensor + %7582 = stablehlo.dynamic_iota %from_elements_2642, dim = 1 : (tensor<3xi64>) -> tensor + %7583 = stablehlo.concatenate %7581, %7582, dim = 2 : (tensor, tensor) -> tensor + %7584 = "stablehlo.scatter"(%cst_2, %7583, %7580) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %7585 = stablehlo.slice %7519 [1:2, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %7586 = stablehlo.reshape %7585 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %7587 = stablehlo.custom_call @byteir.non_zero(%7586) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_2643 = tensor.dim %7587, %c0 : tensor + %7588 = arith.index_cast %dim_2643 : index to i64 + %from_elements_2644 = tensor.from_elements %7588, %c1_i64 : tensor<2xi64> + %7589 = 
stablehlo.real_dynamic_slice %7587, %c_22, %from_elements_2644, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2645 = tensor.dim %7589, %c0 : tensor + %7590 = arith.index_cast %dim_2645 : index to i64 + %from_elements_2646 = tensor.from_elements %7590 : tensor<1xi64> + %7591 = stablehlo.dynamic_reshape %7589, %from_elements_2646 : (tensor, tensor<1xi64>) -> tensor + %from_elements_2647 = tensor.from_elements %7588, %c2_i64 : tensor<2xi64> + %7592 = stablehlo.real_dynamic_slice %7587, %c_24, %from_elements_2647, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2648 = tensor.dim %7592, %c0 : tensor + %7593 = arith.index_cast %dim_2648 : index to i64 + %from_elements_2649 = tensor.from_elements %7593 : tensor<1xi64> + %7594 = stablehlo.dynamic_reshape %7592, %from_elements_2649 : (tensor, tensor<1xi64>) -> tensor + %dim_2650 = tensor.dim %7594, %c0 : tensor + %7595 = arith.index_cast %dim_2650 : index to i64 + %from_elements_2651 = tensor.from_elements %7595, %c1_i64 : tensor<2xi64> + %7596 = stablehlo.dynamic_reshape %7594, %from_elements_2651 : (tensor, tensor<2xi64>) -> tensor + %dim_2652 = tensor.dim %7596, %c0 : tensor + %7597 = arith.index_cast %dim_2652 : index to i64 + %from_elements_2653 = tensor.from_elements %c1_i64, %7597, %c4096_i64 : tensor<3xi64> + %7598 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_2653, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2654 = tensor.dim %7598, %c1 : tensor<1x?x4096xi64> + %7599 = arith.index_cast %dim_2654 : index to i64 + %from_elements_2655 = tensor.from_elements %c1_i64, %7599, %c4096_i64, %c1_i64 : tensor<4xi64> + %7600 = stablehlo.dynamic_reshape %7598, %from_elements_2655 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %7601 = stablehlo.dynamic_broadcast_in_dim %7596, %from_elements_2653, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2656 = tensor.dim %7601, %c1 : tensor<1x?x4096xi64> + %7602 = arith.index_cast %dim_2656 : index to i64 + %from_elements_2657 = tensor.from_elements %c1_i64, %7602, %c4096_i64, %c1_i64 : tensor<4xi64> + %7603 = stablehlo.dynamic_reshape %7601, %from_elements_2657 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %7604 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_2653, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2658 = tensor.dim %7604, %c1 : tensor<1x?x4096xi64> + %7605 = arith.index_cast %dim_2658 : index to i64 + %from_elements_2659 = tensor.from_elements %c1_i64, %7605, %c4096_i64, %c1_i64 : tensor<4xi64> + %7606 = stablehlo.dynamic_reshape %7604, %from_elements_2659 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %7607 = stablehlo.concatenate %7600, %7603, %7606, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %7608 = "stablehlo.gather"(%7530, %7607) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %7609 = shape.shape_of %7608 : tensor<1x?x4096xf32> -> tensor<3xindex> + %7610 = shape.num_elements %7609 : tensor<3xindex> -> index + %7611 = stablehlo.compute_reshape_shape %7610, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %7612 = stablehlo.dynamic_reshape %7608, %7611 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %7613 = stablehlo.dot %7612, %190 : (tensor, tensor<4096x14336xf32>) 
-> tensor + %7614 = stablehlo.logistic %7613 : tensor + %7615 = shape.shape_of %7614 : tensor -> tensor<2xindex> + %7616 = shape.shape_of %7613 : tensor -> tensor<2xindex> + %7617 = shape.cstr_broadcastable %7615, %7616 : tensor<2xindex>, tensor<2xindex> + %7618 = shape.assuming %7617 -> (tensor) { + %19688 = shape.broadcast %7615, %7616 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %7614, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %7613, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %7619 = shape.shape_of %7618 : tensor -> tensor<2xindex> + %7620 = shape.cstr_broadcastable %7619, %7616 : tensor<2xindex>, tensor<2xindex> + %7621 = shape.assuming %7620 -> (tensor) { + %19688 = shape.broadcast %7619, %7616 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %7618, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %7613, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %7622 = stablehlo.dot %7621, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_2660 = tensor.dim %7594, %c0 : tensor + %7623 = arith.index_cast %dim_2660 : index to i64 + %from_elements_2661 = tensor.from_elements %7623, %c1_i64 : tensor<2xi64> + %7624 = stablehlo.dynamic_reshape %7594, %from_elements_2661 : (tensor, tensor<2xi64>) -> tensor + %dim_2662 = tensor.dim %7591, %c0 : tensor + %7625 = arith.index_cast %dim_2662 : index to i64 + %from_elements_2663 = tensor.from_elements %7625, %c1_i64 : tensor<2xi64> + %7626 = stablehlo.dynamic_reshape %7591, %from_elements_2663 : (tensor, tensor<2xi64>) -> tensor + %7627 = stablehlo.concatenate %7624, %7626, dim = 1 : (tensor, tensor) -> tensor + %7628 = "stablehlo.gather"(%7559, %7627) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %7629 = shape.shape_of %7622 : tensor -> tensor<2xindex> + %7630 = shape.shape_of %7628 : tensor -> tensor<2xindex> + %7631 = shape.cstr_broadcastable %7629, %7630 : tensor<2xindex>, tensor<2xindex> + %7632 = shape.assuming %7631 -> (tensor) { + %19688 = shape.broadcast %7629, %7630 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %7622, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %7628, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %7633 = shape.shape_of %7632 : tensor -> tensor<2xindex> + %7634 = stablehlo.dynamic_broadcast_in_dim %7632, %7633, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %7635 = stablehlo.dynamic_broadcast_in_dim %213, %7633, dims = [] : (tensor, tensor<2xindex>) -> tensor + %7636 = stablehlo.multiply %7634, %7635 : tensor + %dim_2664 = tensor.dim %7596, %c0 : tensor + %7637 = arith.index_cast %dim_2664 : index to i64 + %dim_2665 = tensor.dim %7632, %c0 : tensor + %7638 = arith.index_cast %dim_2665 : index to i64 + %7639 = arith.maxsi %7637, %7638 : i64 + %7640 = arith.index_cast %7639 : i64 to index + %from_elements_2666 = tensor.from_elements %7640, %c4096 : tensor<2xindex> + 
%7641 = stablehlo.dynamic_broadcast_in_dim %7596, %from_elements_2666, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_2667 = tensor.dim %7641, %c0 : tensor + %7642 = arith.index_cast %dim_2667 : index to i64 + %from_elements_2668 = tensor.from_elements %7642, %c4096_i64 : tensor<2xi64> + %7643 = stablehlo.real_dynamic_slice %7636, %c_22, %from_elements_2668, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_2669 = tensor.from_elements %7642, %c4096_i64, %c1_i64 : tensor<3xi64> + %7644 = stablehlo.dynamic_reshape %7641, %from_elements_2669 : (tensor, tensor<3xi64>) -> tensor + %7645 = stablehlo.dynamic_iota %from_elements_2669, dim = 1 : (tensor<3xi64>) -> tensor + %7646 = stablehlo.concatenate %7644, %7645, dim = 2 : (tensor, tensor) -> tensor + %7647 = "stablehlo.scatter"(%7584, %7646, %7643) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %7648 = stablehlo.slice %7519 [2:3, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %7649 = stablehlo.reshape %7648 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %7650 = stablehlo.custom_call @byteir.non_zero(%7649) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_2670 = tensor.dim %7650, %c0 : tensor + %7651 = arith.index_cast %dim_2670 : index to i64 + %from_elements_2671 = tensor.from_elements %7651, %c1_i64 : tensor<2xi64> + %7652 = stablehlo.real_dynamic_slice %7650, %c_22, %from_elements_2671, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2672 = tensor.dim %7652, %c0 : tensor + %7653 = arith.index_cast %dim_2672 : index to i64 + %from_elements_2673 = tensor.from_elements %7653 : tensor<1xi64> + %7654 = stablehlo.dynamic_reshape %7652, %from_elements_2673 : (tensor, tensor<1xi64>) -> tensor + %from_elements_2674 = tensor.from_elements %7651, %c2_i64 : tensor<2xi64> + %7655 = stablehlo.real_dynamic_slice %7650, %c_24, %from_elements_2674, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2675 = tensor.dim %7655, %c0 : tensor + %7656 = arith.index_cast %dim_2675 : index to i64 + %from_elements_2676 = tensor.from_elements %7656 : tensor<1xi64> + %7657 = stablehlo.dynamic_reshape %7655, %from_elements_2676 : (tensor, tensor<1xi64>) -> tensor + %dim_2677 = tensor.dim %7657, %c0 : tensor + %7658 = arith.index_cast %dim_2677 : index to i64 + %from_elements_2678 = tensor.from_elements %7658, %c1_i64 : tensor<2xi64> + %7659 = stablehlo.dynamic_reshape %7657, %from_elements_2678 : (tensor, tensor<2xi64>) -> tensor + %dim_2679 = tensor.dim %7659, %c0 : tensor + %7660 = arith.index_cast %dim_2679 : index to i64 + %from_elements_2680 = tensor.from_elements %c1_i64, %7660, %c4096_i64 : tensor<3xi64> + %7661 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_2680, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2681 = tensor.dim %7661, %c1 : tensor<1x?x4096xi64> + %7662 = arith.index_cast %dim_2681 : index to i64 + %from_elements_2682 = tensor.from_elements %c1_i64, %7662, %c4096_i64, %c1_i64 : tensor<4xi64> + %7663 = stablehlo.dynamic_reshape %7661, %from_elements_2682 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %7664 = stablehlo.dynamic_broadcast_in_dim %7659, %from_elements_2680, dims = [1, 2] : (tensor, tensor<3xi64>) -> 
tensor<1x?x4096xi64> + %dim_2683 = tensor.dim %7664, %c1 : tensor<1x?x4096xi64> + %7665 = arith.index_cast %dim_2683 : index to i64 + %from_elements_2684 = tensor.from_elements %c1_i64, %7665, %c4096_i64, %c1_i64 : tensor<4xi64> + %7666 = stablehlo.dynamic_reshape %7664, %from_elements_2684 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %7667 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_2680, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2685 = tensor.dim %7667, %c1 : tensor<1x?x4096xi64> + %7668 = arith.index_cast %dim_2685 : index to i64 + %from_elements_2686 = tensor.from_elements %c1_i64, %7668, %c4096_i64, %c1_i64 : tensor<4xi64> + %7669 = stablehlo.dynamic_reshape %7667, %from_elements_2686 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %7670 = stablehlo.concatenate %7663, %7666, %7669, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %7671 = "stablehlo.gather"(%7530, %7670) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %7672 = shape.shape_of %7671 : tensor<1x?x4096xf32> -> tensor<3xindex> + %7673 = shape.num_elements %7672 : tensor<3xindex> -> index + %7674 = stablehlo.compute_reshape_shape %7673, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %7675 = stablehlo.dynamic_reshape %7671, %7674 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %7676 = stablehlo.dot %7675, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %7677 = stablehlo.logistic %7676 : tensor + %7678 = shape.shape_of %7677 : tensor -> tensor<2xindex> + %7679 = shape.shape_of %7676 : tensor -> tensor<2xindex> + %7680 = shape.cstr_broadcastable %7678, %7679 : tensor<2xindex>, tensor<2xindex> + %7681 = shape.assuming %7680 -> (tensor) { + %19688 = shape.broadcast %7678, %7679 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %7677, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %7676, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %7682 = shape.shape_of %7681 : tensor -> tensor<2xindex> + %7683 = shape.cstr_broadcastable %7682, %7679 : tensor<2xindex>, tensor<2xindex> + %7684 = shape.assuming %7683 -> (tensor) { + %19688 = shape.broadcast %7682, %7679 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %7681, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %7676, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %7685 = stablehlo.dot %7684, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_2687 = tensor.dim %7657, %c0 : tensor + %7686 = arith.index_cast %dim_2687 : index to i64 + %from_elements_2688 = tensor.from_elements %7686, %c1_i64 : tensor<2xi64> + %7687 = stablehlo.dynamic_reshape %7657, %from_elements_2688 : (tensor, tensor<2xi64>) -> tensor + %dim_2689 = tensor.dim %7654, %c0 : tensor + %7688 = arith.index_cast %dim_2689 : index to i64 + %from_elements_2690 = tensor.from_elements %7688, %c1_i64 : tensor<2xi64> + %7689 = stablehlo.dynamic_reshape %7654, %from_elements_2690 : (tensor, tensor<2xi64>) -> tensor 
+ %7690 = stablehlo.concatenate %7687, %7689, dim = 1 : (tensor, tensor) -> tensor + %7691 = "stablehlo.gather"(%7559, %7690) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %7692 = shape.shape_of %7685 : tensor -> tensor<2xindex> + %7693 = shape.shape_of %7691 : tensor -> tensor<2xindex> + %7694 = shape.cstr_broadcastable %7692, %7693 : tensor<2xindex>, tensor<2xindex> + %7695 = shape.assuming %7694 -> (tensor) { + %19688 = shape.broadcast %7692, %7693 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %7685, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %7691, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %7696 = shape.shape_of %7695 : tensor -> tensor<2xindex> + %7697 = stablehlo.dynamic_broadcast_in_dim %7695, %7696, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %7698 = stablehlo.dynamic_broadcast_in_dim %213, %7696, dims = [] : (tensor, tensor<2xindex>) -> tensor + %7699 = stablehlo.multiply %7697, %7698 : tensor + %dim_2691 = tensor.dim %7659, %c0 : tensor + %7700 = arith.index_cast %dim_2691 : index to i64 + %dim_2692 = tensor.dim %7695, %c0 : tensor + %7701 = arith.index_cast %dim_2692 : index to i64 + %7702 = arith.maxsi %7700, %7701 : i64 + %7703 = arith.index_cast %7702 : i64 to index + %from_elements_2693 = tensor.from_elements %7703, %c4096 : tensor<2xindex> + %7704 = stablehlo.dynamic_broadcast_in_dim %7659, %from_elements_2693, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_2694 = tensor.dim %7704, %c0 : tensor + %7705 = arith.index_cast %dim_2694 : index to i64 + %from_elements_2695 = tensor.from_elements %7705, %c4096_i64 : tensor<2xi64> + %7706 = stablehlo.real_dynamic_slice %7699, %c_22, %from_elements_2695, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_2696 = tensor.from_elements %7705, %c4096_i64, %c1_i64 : tensor<3xi64> + %7707 = stablehlo.dynamic_reshape %7704, %from_elements_2696 : (tensor, tensor<3xi64>) -> tensor + %7708 = stablehlo.dynamic_iota %from_elements_2696, dim = 1 : (tensor<3xi64>) -> tensor + %7709 = stablehlo.concatenate %7707, %7708, dim = 2 : (tensor, tensor) -> tensor + %7710 = "stablehlo.scatter"(%7647, %7709, %7706) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %7711 = stablehlo.slice %7519 [3:4, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %7712 = stablehlo.reshape %7711 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %7713 = stablehlo.custom_call @byteir.non_zero(%7712) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_2697 = tensor.dim %7713, %c0 : tensor + %7714 = arith.index_cast %dim_2697 : index to i64 + %from_elements_2698 = tensor.from_elements %7714, %c1_i64 : tensor<2xi64> + %7715 = stablehlo.real_dynamic_slice %7713, %c_22, %from_elements_2698, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2699 = tensor.dim %7715, %c0 : tensor + %7716 = arith.index_cast %dim_2699 : index to i64 + %from_elements_2700 = tensor.from_elements %7716 : tensor<1xi64> + %7717 = 
stablehlo.dynamic_reshape %7715, %from_elements_2700 : (tensor, tensor<1xi64>) -> tensor + %from_elements_2701 = tensor.from_elements %7714, %c2_i64 : tensor<2xi64> + %7718 = stablehlo.real_dynamic_slice %7713, %c_24, %from_elements_2701, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2702 = tensor.dim %7718, %c0 : tensor + %7719 = arith.index_cast %dim_2702 : index to i64 + %from_elements_2703 = tensor.from_elements %7719 : tensor<1xi64> + %7720 = stablehlo.dynamic_reshape %7718, %from_elements_2703 : (tensor, tensor<1xi64>) -> tensor + %dim_2704 = tensor.dim %7720, %c0 : tensor + %7721 = arith.index_cast %dim_2704 : index to i64 + %from_elements_2705 = tensor.from_elements %7721, %c1_i64 : tensor<2xi64> + %7722 = stablehlo.dynamic_reshape %7720, %from_elements_2705 : (tensor, tensor<2xi64>) -> tensor + %dim_2706 = tensor.dim %7722, %c0 : tensor + %7723 = arith.index_cast %dim_2706 : index to i64 + %from_elements_2707 = tensor.from_elements %c1_i64, %7723, %c4096_i64 : tensor<3xi64> + %7724 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_2707, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2708 = tensor.dim %7724, %c1 : tensor<1x?x4096xi64> + %7725 = arith.index_cast %dim_2708 : index to i64 + %from_elements_2709 = tensor.from_elements %c1_i64, %7725, %c4096_i64, %c1_i64 : tensor<4xi64> + %7726 = stablehlo.dynamic_reshape %7724, %from_elements_2709 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %7727 = stablehlo.dynamic_broadcast_in_dim %7722, %from_elements_2707, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2710 = tensor.dim %7727, %c1 : tensor<1x?x4096xi64> + %7728 = arith.index_cast %dim_2710 : index to i64 + %from_elements_2711 = tensor.from_elements %c1_i64, %7728, %c4096_i64, %c1_i64 : tensor<4xi64> + %7729 = stablehlo.dynamic_reshape %7727, %from_elements_2711 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %7730 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_2707, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2712 = tensor.dim %7730, %c1 : tensor<1x?x4096xi64> + %7731 = arith.index_cast %dim_2712 : index to i64 + %from_elements_2713 = tensor.from_elements %c1_i64, %7731, %c4096_i64, %c1_i64 : tensor<4xi64> + %7732 = stablehlo.dynamic_reshape %7730, %from_elements_2713 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %7733 = stablehlo.concatenate %7726, %7729, %7732, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %7734 = "stablehlo.gather"(%7530, %7733) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %7735 = shape.shape_of %7734 : tensor<1x?x4096xf32> -> tensor<3xindex> + %7736 = shape.num_elements %7735 : tensor<3xindex> -> index + %7737 = stablehlo.compute_reshape_shape %7736, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %7738 = stablehlo.dynamic_reshape %7734, %7737 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %7739 = stablehlo.dot %7738, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %7740 = stablehlo.logistic %7739 : tensor + %7741 = shape.shape_of %7740 : tensor -> tensor<2xindex> + %7742 = shape.shape_of %7739 : tensor -> tensor<2xindex> + %7743 = shape.cstr_broadcastable %7741, %7742 : tensor<2xindex>, tensor<2xindex> + %7744 = shape.assuming %7743 -> (tensor) { + %19688 = 
shape.broadcast %7741, %7742 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %7740, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %7739, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %7745 = shape.shape_of %7744 : tensor -> tensor<2xindex> + %7746 = shape.cstr_broadcastable %7745, %7742 : tensor<2xindex>, tensor<2xindex> + %7747 = shape.assuming %7746 -> (tensor) { + %19688 = shape.broadcast %7745, %7742 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %7744, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %7739, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %7748 = stablehlo.dot %7747, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_2714 = tensor.dim %7720, %c0 : tensor + %7749 = arith.index_cast %dim_2714 : index to i64 + %from_elements_2715 = tensor.from_elements %7749, %c1_i64 : tensor<2xi64> + %7750 = stablehlo.dynamic_reshape %7720, %from_elements_2715 : (tensor, tensor<2xi64>) -> tensor + %dim_2716 = tensor.dim %7717, %c0 : tensor + %7751 = arith.index_cast %dim_2716 : index to i64 + %from_elements_2717 = tensor.from_elements %7751, %c1_i64 : tensor<2xi64> + %7752 = stablehlo.dynamic_reshape %7717, %from_elements_2717 : (tensor, tensor<2xi64>) -> tensor + %7753 = stablehlo.concatenate %7750, %7752, dim = 1 : (tensor, tensor) -> tensor + %7754 = "stablehlo.gather"(%7559, %7753) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %7755 = shape.shape_of %7748 : tensor -> tensor<2xindex> + %7756 = shape.shape_of %7754 : tensor -> tensor<2xindex> + %7757 = shape.cstr_broadcastable %7755, %7756 : tensor<2xindex>, tensor<2xindex> + %7758 = shape.assuming %7757 -> (tensor) { + %19688 = shape.broadcast %7755, %7756 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %7748, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %7754, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %7759 = shape.shape_of %7758 : tensor -> tensor<2xindex> + %7760 = stablehlo.dynamic_broadcast_in_dim %7758, %7759, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %7761 = stablehlo.dynamic_broadcast_in_dim %213, %7759, dims = [] : (tensor, tensor<2xindex>) -> tensor + %7762 = stablehlo.multiply %7760, %7761 : tensor + %dim_2718 = tensor.dim %7722, %c0 : tensor + %7763 = arith.index_cast %dim_2718 : index to i64 + %dim_2719 = tensor.dim %7758, %c0 : tensor + %7764 = arith.index_cast %dim_2719 : index to i64 + %7765 = arith.maxsi %7763, %7764 : i64 + %7766 = arith.index_cast %7765 : i64 to index + %from_elements_2720 = tensor.from_elements %7766, %c4096 : tensor<2xindex> + %7767 = stablehlo.dynamic_broadcast_in_dim %7722, %from_elements_2720, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_2721 = tensor.dim %7767, %c0 : tensor + %7768 = arith.index_cast %dim_2721 : index to i64 + %from_elements_2722 = tensor.from_elements %7768, %c4096_i64 : tensor<2xi64> + %7769 = 
stablehlo.real_dynamic_slice %7762, %c_22, %from_elements_2722, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_2723 = tensor.from_elements %7768, %c4096_i64, %c1_i64 : tensor<3xi64> + %7770 = stablehlo.dynamic_reshape %7767, %from_elements_2723 : (tensor, tensor<3xi64>) -> tensor + %7771 = stablehlo.dynamic_iota %from_elements_2723, dim = 1 : (tensor<3xi64>) -> tensor + %7772 = stablehlo.concatenate %7770, %7771, dim = 2 : (tensor, tensor) -> tensor + %7773 = "stablehlo.scatter"(%7710, %7772, %7769) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %7774 = stablehlo.slice %7519 [4:5, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %7775 = stablehlo.reshape %7774 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %7776 = stablehlo.custom_call @byteir.non_zero(%7775) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_2724 = tensor.dim %7776, %c0 : tensor + %7777 = arith.index_cast %dim_2724 : index to i64 + %from_elements_2725 = tensor.from_elements %7777, %c1_i64 : tensor<2xi64> + %7778 = stablehlo.real_dynamic_slice %7776, %c_22, %from_elements_2725, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2726 = tensor.dim %7778, %c0 : tensor + %7779 = arith.index_cast %dim_2726 : index to i64 + %from_elements_2727 = tensor.from_elements %7779 : tensor<1xi64> + %7780 = stablehlo.dynamic_reshape %7778, %from_elements_2727 : (tensor, tensor<1xi64>) -> tensor + %from_elements_2728 = tensor.from_elements %7777, %c2_i64 : tensor<2xi64> + %7781 = stablehlo.real_dynamic_slice %7776, %c_24, %from_elements_2728, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2729 = tensor.dim %7781, %c0 : tensor + %7782 = arith.index_cast %dim_2729 : index to i64 + %from_elements_2730 = tensor.from_elements %7782 : tensor<1xi64> + %7783 = stablehlo.dynamic_reshape %7781, %from_elements_2730 : (tensor, tensor<1xi64>) -> tensor + %dim_2731 = tensor.dim %7783, %c0 : tensor + %7784 = arith.index_cast %dim_2731 : index to i64 + %from_elements_2732 = tensor.from_elements %7784, %c1_i64 : tensor<2xi64> + %7785 = stablehlo.dynamic_reshape %7783, %from_elements_2732 : (tensor, tensor<2xi64>) -> tensor + %dim_2733 = tensor.dim %7785, %c0 : tensor + %7786 = arith.index_cast %dim_2733 : index to i64 + %from_elements_2734 = tensor.from_elements %c1_i64, %7786, %c4096_i64 : tensor<3xi64> + %7787 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_2734, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2735 = tensor.dim %7787, %c1 : tensor<1x?x4096xi64> + %7788 = arith.index_cast %dim_2735 : index to i64 + %from_elements_2736 = tensor.from_elements %c1_i64, %7788, %c4096_i64, %c1_i64 : tensor<4xi64> + %7789 = stablehlo.dynamic_reshape %7787, %from_elements_2736 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %7790 = stablehlo.dynamic_broadcast_in_dim %7785, %from_elements_2734, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2737 = tensor.dim %7790, %c1 : tensor<1x?x4096xi64> + %7791 = arith.index_cast %dim_2737 : index to i64 + %from_elements_2738 = tensor.from_elements %c1_i64, %7791, %c4096_i64, %c1_i64 : tensor<4xi64> + %7792 = stablehlo.dynamic_reshape %7790, %from_elements_2738 : 
(tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %7793 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_2734, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2739 = tensor.dim %7793, %c1 : tensor<1x?x4096xi64> + %7794 = arith.index_cast %dim_2739 : index to i64 + %from_elements_2740 = tensor.from_elements %c1_i64, %7794, %c4096_i64, %c1_i64 : tensor<4xi64> + %7795 = stablehlo.dynamic_reshape %7793, %from_elements_2740 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %7796 = stablehlo.concatenate %7789, %7792, %7795, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %7797 = "stablehlo.gather"(%7530, %7796) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %7798 = shape.shape_of %7797 : tensor<1x?x4096xf32> -> tensor<3xindex> + %7799 = shape.num_elements %7798 : tensor<3xindex> -> index + %7800 = stablehlo.compute_reshape_shape %7799, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %7801 = stablehlo.dynamic_reshape %7797, %7800 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %7802 = stablehlo.dot %7801, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %7803 = stablehlo.logistic %7802 : tensor + %7804 = shape.shape_of %7803 : tensor -> tensor<2xindex> + %7805 = shape.shape_of %7802 : tensor -> tensor<2xindex> + %7806 = shape.cstr_broadcastable %7804, %7805 : tensor<2xindex>, tensor<2xindex> + %7807 = shape.assuming %7806 -> (tensor) { + %19688 = shape.broadcast %7804, %7805 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %7803, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %7802, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %7808 = shape.shape_of %7807 : tensor -> tensor<2xindex> + %7809 = shape.cstr_broadcastable %7808, %7805 : tensor<2xindex>, tensor<2xindex> + %7810 = shape.assuming %7809 -> (tensor) { + %19688 = shape.broadcast %7808, %7805 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %7807, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %7802, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %7811 = stablehlo.dot %7810, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_2741 = tensor.dim %7783, %c0 : tensor + %7812 = arith.index_cast %dim_2741 : index to i64 + %from_elements_2742 = tensor.from_elements %7812, %c1_i64 : tensor<2xi64> + %7813 = stablehlo.dynamic_reshape %7783, %from_elements_2742 : (tensor, tensor<2xi64>) -> tensor + %dim_2743 = tensor.dim %7780, %c0 : tensor + %7814 = arith.index_cast %dim_2743 : index to i64 + %from_elements_2744 = tensor.from_elements %7814, %c1_i64 : tensor<2xi64> + %7815 = stablehlo.dynamic_reshape %7780, %from_elements_2744 : (tensor, tensor<2xi64>) -> tensor + %7816 = stablehlo.concatenate %7813, %7815, dim = 1 : (tensor, tensor) -> tensor + %7817 = "stablehlo.gather"(%7559, %7816) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %7818 = shape.shape_of %7811 : 
tensor -> tensor<2xindex> + %7819 = shape.shape_of %7817 : tensor -> tensor<2xindex> + %7820 = shape.cstr_broadcastable %7818, %7819 : tensor<2xindex>, tensor<2xindex> + %7821 = shape.assuming %7820 -> (tensor) { + %19688 = shape.broadcast %7818, %7819 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %7811, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %7817, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %7822 = shape.shape_of %7821 : tensor -> tensor<2xindex> + %7823 = stablehlo.dynamic_broadcast_in_dim %7821, %7822, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %7824 = stablehlo.dynamic_broadcast_in_dim %213, %7822, dims = [] : (tensor, tensor<2xindex>) -> tensor + %7825 = stablehlo.multiply %7823, %7824 : tensor + %dim_2745 = tensor.dim %7785, %c0 : tensor + %7826 = arith.index_cast %dim_2745 : index to i64 + %dim_2746 = tensor.dim %7821, %c0 : tensor + %7827 = arith.index_cast %dim_2746 : index to i64 + %7828 = arith.maxsi %7826, %7827 : i64 + %7829 = arith.index_cast %7828 : i64 to index + %from_elements_2747 = tensor.from_elements %7829, %c4096 : tensor<2xindex> + %7830 = stablehlo.dynamic_broadcast_in_dim %7785, %from_elements_2747, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_2748 = tensor.dim %7830, %c0 : tensor + %7831 = arith.index_cast %dim_2748 : index to i64 + %from_elements_2749 = tensor.from_elements %7831, %c4096_i64 : tensor<2xi64> + %7832 = stablehlo.real_dynamic_slice %7825, %c_22, %from_elements_2749, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_2750 = tensor.from_elements %7831, %c4096_i64, %c1_i64 : tensor<3xi64> + %7833 = stablehlo.dynamic_reshape %7830, %from_elements_2750 : (tensor, tensor<3xi64>) -> tensor + %7834 = stablehlo.dynamic_iota %from_elements_2750, dim = 1 : (tensor<3xi64>) -> tensor + %7835 = stablehlo.concatenate %7833, %7834, dim = 2 : (tensor, tensor) -> tensor + %7836 = "stablehlo.scatter"(%7773, %7835, %7832) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %7837 = stablehlo.slice %7519 [5:6, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %7838 = stablehlo.reshape %7837 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %7839 = stablehlo.custom_call @byteir.non_zero(%7838) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_2751 = tensor.dim %7839, %c0 : tensor + %7840 = arith.index_cast %dim_2751 : index to i64 + %from_elements_2752 = tensor.from_elements %7840, %c1_i64 : tensor<2xi64> + %7841 = stablehlo.real_dynamic_slice %7839, %c_22, %from_elements_2752, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2753 = tensor.dim %7841, %c0 : tensor + %7842 = arith.index_cast %dim_2753 : index to i64 + %from_elements_2754 = tensor.from_elements %7842 : tensor<1xi64> + %7843 = stablehlo.dynamic_reshape %7841, %from_elements_2754 : (tensor, tensor<1xi64>) -> tensor + %from_elements_2755 = tensor.from_elements %7840, %c2_i64 : tensor<2xi64> + %7844 = stablehlo.real_dynamic_slice %7839, %c_24, %from_elements_2755, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + 
%dim_2756 = tensor.dim %7844, %c0 : tensor + %7845 = arith.index_cast %dim_2756 : index to i64 + %from_elements_2757 = tensor.from_elements %7845 : tensor<1xi64> + %7846 = stablehlo.dynamic_reshape %7844, %from_elements_2757 : (tensor, tensor<1xi64>) -> tensor + %dim_2758 = tensor.dim %7846, %c0 : tensor + %7847 = arith.index_cast %dim_2758 : index to i64 + %from_elements_2759 = tensor.from_elements %7847, %c1_i64 : tensor<2xi64> + %7848 = stablehlo.dynamic_reshape %7846, %from_elements_2759 : (tensor, tensor<2xi64>) -> tensor + %dim_2760 = tensor.dim %7848, %c0 : tensor + %7849 = arith.index_cast %dim_2760 : index to i64 + %from_elements_2761 = tensor.from_elements %c1_i64, %7849, %c4096_i64 : tensor<3xi64> + %7850 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_2761, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2762 = tensor.dim %7850, %c1 : tensor<1x?x4096xi64> + %7851 = arith.index_cast %dim_2762 : index to i64 + %from_elements_2763 = tensor.from_elements %c1_i64, %7851, %c4096_i64, %c1_i64 : tensor<4xi64> + %7852 = stablehlo.dynamic_reshape %7850, %from_elements_2763 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %7853 = stablehlo.dynamic_broadcast_in_dim %7848, %from_elements_2761, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2764 = tensor.dim %7853, %c1 : tensor<1x?x4096xi64> + %7854 = arith.index_cast %dim_2764 : index to i64 + %from_elements_2765 = tensor.from_elements %c1_i64, %7854, %c4096_i64, %c1_i64 : tensor<4xi64> + %7855 = stablehlo.dynamic_reshape %7853, %from_elements_2765 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %7856 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_2761, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2766 = tensor.dim %7856, %c1 : tensor<1x?x4096xi64> + %7857 = arith.index_cast %dim_2766 : index to i64 + %from_elements_2767 = tensor.from_elements %c1_i64, %7857, %c4096_i64, %c1_i64 : tensor<4xi64> + %7858 = stablehlo.dynamic_reshape %7856, %from_elements_2767 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %7859 = stablehlo.concatenate %7852, %7855, %7858, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %7860 = "stablehlo.gather"(%7530, %7859) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %7861 = shape.shape_of %7860 : tensor<1x?x4096xf32> -> tensor<3xindex> + %7862 = shape.num_elements %7861 : tensor<3xindex> -> index + %7863 = stablehlo.compute_reshape_shape %7862, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %7864 = stablehlo.dynamic_reshape %7860, %7863 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %7865 = stablehlo.dot %7864, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %7866 = stablehlo.logistic %7865 : tensor + %7867 = shape.shape_of %7866 : tensor -> tensor<2xindex> + %7868 = shape.shape_of %7865 : tensor -> tensor<2xindex> + %7869 = shape.cstr_broadcastable %7867, %7868 : tensor<2xindex>, tensor<2xindex> + %7870 = shape.assuming %7869 -> (tensor) { + %19688 = shape.broadcast %7867, %7868 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %7866, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %7865, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + 
%19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %7871 = shape.shape_of %7870 : tensor -> tensor<2xindex> + %7872 = shape.cstr_broadcastable %7871, %7868 : tensor<2xindex>, tensor<2xindex> + %7873 = shape.assuming %7872 -> (tensor) { + %19688 = shape.broadcast %7871, %7868 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %7870, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %7865, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %7874 = stablehlo.dot %7873, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_2768 = tensor.dim %7846, %c0 : tensor + %7875 = arith.index_cast %dim_2768 : index to i64 + %from_elements_2769 = tensor.from_elements %7875, %c1_i64 : tensor<2xi64> + %7876 = stablehlo.dynamic_reshape %7846, %from_elements_2769 : (tensor, tensor<2xi64>) -> tensor + %dim_2770 = tensor.dim %7843, %c0 : tensor + %7877 = arith.index_cast %dim_2770 : index to i64 + %from_elements_2771 = tensor.from_elements %7877, %c1_i64 : tensor<2xi64> + %7878 = stablehlo.dynamic_reshape %7843, %from_elements_2771 : (tensor, tensor<2xi64>) -> tensor + %7879 = stablehlo.concatenate %7876, %7878, dim = 1 : (tensor, tensor) -> tensor + %7880 = "stablehlo.gather"(%7559, %7879) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %7881 = shape.shape_of %7874 : tensor -> tensor<2xindex> + %7882 = shape.shape_of %7880 : tensor -> tensor<2xindex> + %7883 = shape.cstr_broadcastable %7881, %7882 : tensor<2xindex>, tensor<2xindex> + %7884 = shape.assuming %7883 -> (tensor) { + %19688 = shape.broadcast %7881, %7882 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %7874, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %7880, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %7885 = shape.shape_of %7884 : tensor -> tensor<2xindex> + %7886 = stablehlo.dynamic_broadcast_in_dim %7884, %7885, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %7887 = stablehlo.dynamic_broadcast_in_dim %213, %7885, dims = [] : (tensor, tensor<2xindex>) -> tensor + %7888 = stablehlo.multiply %7886, %7887 : tensor + %dim_2772 = tensor.dim %7848, %c0 : tensor + %7889 = arith.index_cast %dim_2772 : index to i64 + %dim_2773 = tensor.dim %7884, %c0 : tensor + %7890 = arith.index_cast %dim_2773 : index to i64 + %7891 = arith.maxsi %7889, %7890 : i64 + %7892 = arith.index_cast %7891 : i64 to index + %from_elements_2774 = tensor.from_elements %7892, %c4096 : tensor<2xindex> + %7893 = stablehlo.dynamic_broadcast_in_dim %7848, %from_elements_2774, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_2775 = tensor.dim %7893, %c0 : tensor + %7894 = arith.index_cast %dim_2775 : index to i64 + %from_elements_2776 = tensor.from_elements %7894, %c4096_i64 : tensor<2xi64> + %7895 = stablehlo.real_dynamic_slice %7888, %c_22, %from_elements_2776, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_2777 = tensor.from_elements %7894, %c4096_i64, %c1_i64 : tensor<3xi64> + %7896 = stablehlo.dynamic_reshape %7893, %from_elements_2777 : (tensor, tensor<3xi64>) 
-> tensor + %7897 = stablehlo.dynamic_iota %from_elements_2777, dim = 1 : (tensor<3xi64>) -> tensor + %7898 = stablehlo.concatenate %7896, %7897, dim = 2 : (tensor, tensor) -> tensor + %7899 = "stablehlo.scatter"(%7836, %7898, %7895) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %7900 = stablehlo.slice %7519 [6:7, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %7901 = stablehlo.reshape %7900 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %7902 = stablehlo.custom_call @byteir.non_zero(%7901) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_2778 = tensor.dim %7902, %c0 : tensor + %7903 = arith.index_cast %dim_2778 : index to i64 + %from_elements_2779 = tensor.from_elements %7903, %c1_i64 : tensor<2xi64> + %7904 = stablehlo.real_dynamic_slice %7902, %c_22, %from_elements_2779, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2780 = tensor.dim %7904, %c0 : tensor + %7905 = arith.index_cast %dim_2780 : index to i64 + %from_elements_2781 = tensor.from_elements %7905 : tensor<1xi64> + %7906 = stablehlo.dynamic_reshape %7904, %from_elements_2781 : (tensor, tensor<1xi64>) -> tensor + %from_elements_2782 = tensor.from_elements %7903, %c2_i64 : tensor<2xi64> + %7907 = stablehlo.real_dynamic_slice %7902, %c_24, %from_elements_2782, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2783 = tensor.dim %7907, %c0 : tensor + %7908 = arith.index_cast %dim_2783 : index to i64 + %from_elements_2784 = tensor.from_elements %7908 : tensor<1xi64> + %7909 = stablehlo.dynamic_reshape %7907, %from_elements_2784 : (tensor, tensor<1xi64>) -> tensor + %dim_2785 = tensor.dim %7909, %c0 : tensor + %7910 = arith.index_cast %dim_2785 : index to i64 + %from_elements_2786 = tensor.from_elements %7910, %c1_i64 : tensor<2xi64> + %7911 = stablehlo.dynamic_reshape %7909, %from_elements_2786 : (tensor, tensor<2xi64>) -> tensor + %dim_2787 = tensor.dim %7911, %c0 : tensor + %7912 = arith.index_cast %dim_2787 : index to i64 + %from_elements_2788 = tensor.from_elements %c1_i64, %7912, %c4096_i64 : tensor<3xi64> + %7913 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_2788, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2789 = tensor.dim %7913, %c1 : tensor<1x?x4096xi64> + %7914 = arith.index_cast %dim_2789 : index to i64 + %from_elements_2790 = tensor.from_elements %c1_i64, %7914, %c4096_i64, %c1_i64 : tensor<4xi64> + %7915 = stablehlo.dynamic_reshape %7913, %from_elements_2790 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %7916 = stablehlo.dynamic_broadcast_in_dim %7911, %from_elements_2788, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2791 = tensor.dim %7916, %c1 : tensor<1x?x4096xi64> + %7917 = arith.index_cast %dim_2791 : index to i64 + %from_elements_2792 = tensor.from_elements %c1_i64, %7917, %c4096_i64, %c1_i64 : tensor<4xi64> + %7918 = stablehlo.dynamic_reshape %7916, %from_elements_2792 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %7919 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_2788, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2793 = tensor.dim %7919, %c1 : tensor<1x?x4096xi64> + %7920 = arith.index_cast %dim_2793 : index to i64 + 
%from_elements_2794 = tensor.from_elements %c1_i64, %7920, %c4096_i64, %c1_i64 : tensor<4xi64> + %7921 = stablehlo.dynamic_reshape %7919, %from_elements_2794 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %7922 = stablehlo.concatenate %7915, %7918, %7921, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %7923 = "stablehlo.gather"(%7530, %7922) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %7924 = shape.shape_of %7923 : tensor<1x?x4096xf32> -> tensor<3xindex> + %7925 = shape.num_elements %7924 : tensor<3xindex> -> index + %7926 = stablehlo.compute_reshape_shape %7925, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %7927 = stablehlo.dynamic_reshape %7923, %7926 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %7928 = stablehlo.dot %7927, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %7929 = stablehlo.logistic %7928 : tensor + %7930 = shape.shape_of %7929 : tensor -> tensor<2xindex> + %7931 = shape.shape_of %7928 : tensor -> tensor<2xindex> + %7932 = shape.cstr_broadcastable %7930, %7931 : tensor<2xindex>, tensor<2xindex> + %7933 = shape.assuming %7932 -> (tensor) { + %19688 = shape.broadcast %7930, %7931 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %7929, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %7928, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %7934 = shape.shape_of %7933 : tensor -> tensor<2xindex> + %7935 = shape.cstr_broadcastable %7934, %7931 : tensor<2xindex>, tensor<2xindex> + %7936 = shape.assuming %7935 -> (tensor) { + %19688 = shape.broadcast %7934, %7931 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %7933, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %7928, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %7937 = stablehlo.dot %7936, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_2795 = tensor.dim %7909, %c0 : tensor + %7938 = arith.index_cast %dim_2795 : index to i64 + %from_elements_2796 = tensor.from_elements %7938, %c1_i64 : tensor<2xi64> + %7939 = stablehlo.dynamic_reshape %7909, %from_elements_2796 : (tensor, tensor<2xi64>) -> tensor + %dim_2797 = tensor.dim %7906, %c0 : tensor + %7940 = arith.index_cast %dim_2797 : index to i64 + %from_elements_2798 = tensor.from_elements %7940, %c1_i64 : tensor<2xi64> + %7941 = stablehlo.dynamic_reshape %7906, %from_elements_2798 : (tensor, tensor<2xi64>) -> tensor + %7942 = stablehlo.concatenate %7939, %7941, dim = 1 : (tensor, tensor) -> tensor + %7943 = "stablehlo.gather"(%7559, %7942) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %7944 = shape.shape_of %7937 : tensor -> tensor<2xindex> + %7945 = shape.shape_of %7943 : tensor -> tensor<2xindex> + %7946 = shape.cstr_broadcastable %7944, %7945 : tensor<2xindex>, tensor<2xindex> + %7947 = shape.assuming %7946 -> (tensor) { + %19688 = shape.broadcast %7944, %7945 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = 
stablehlo.dynamic_broadcast_in_dim %7937, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %7943, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %7948 = shape.shape_of %7947 : tensor -> tensor<2xindex> + %7949 = stablehlo.dynamic_broadcast_in_dim %7947, %7948, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %7950 = stablehlo.dynamic_broadcast_in_dim %213, %7948, dims = [] : (tensor, tensor<2xindex>) -> tensor + %7951 = stablehlo.multiply %7949, %7950 : tensor + %dim_2799 = tensor.dim %7911, %c0 : tensor + %7952 = arith.index_cast %dim_2799 : index to i64 + %dim_2800 = tensor.dim %7947, %c0 : tensor + %7953 = arith.index_cast %dim_2800 : index to i64 + %7954 = arith.maxsi %7952, %7953 : i64 + %7955 = arith.index_cast %7954 : i64 to index + %from_elements_2801 = tensor.from_elements %7955, %c4096 : tensor<2xindex> + %7956 = stablehlo.dynamic_broadcast_in_dim %7911, %from_elements_2801, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_2802 = tensor.dim %7956, %c0 : tensor + %7957 = arith.index_cast %dim_2802 : index to i64 + %from_elements_2803 = tensor.from_elements %7957, %c4096_i64 : tensor<2xi64> + %7958 = stablehlo.real_dynamic_slice %7951, %c_22, %from_elements_2803, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_2804 = tensor.from_elements %7957, %c4096_i64, %c1_i64 : tensor<3xi64> + %7959 = stablehlo.dynamic_reshape %7956, %from_elements_2804 : (tensor, tensor<3xi64>) -> tensor + %7960 = stablehlo.dynamic_iota %from_elements_2804, dim = 1 : (tensor<3xi64>) -> tensor + %7961 = stablehlo.concatenate %7959, %7960, dim = 2 : (tensor, tensor) -> tensor + %7962 = "stablehlo.scatter"(%7899, %7961, %7958) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %7963 = stablehlo.slice %7519 [7:8, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %7964 = stablehlo.reshape %7963 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %7965 = stablehlo.custom_call @byteir.non_zero(%7964) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_2805 = tensor.dim %7965, %c0 : tensor + %7966 = arith.index_cast %dim_2805 : index to i64 + %from_elements_2806 = tensor.from_elements %7966, %c1_i64 : tensor<2xi64> + %7967 = stablehlo.real_dynamic_slice %7965, %c_22, %from_elements_2806, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2807 = tensor.dim %7967, %c0 : tensor + %7968 = arith.index_cast %dim_2807 : index to i64 + %from_elements_2808 = tensor.from_elements %7968 : tensor<1xi64> + %7969 = stablehlo.dynamic_reshape %7967, %from_elements_2808 : (tensor, tensor<1xi64>) -> tensor + %from_elements_2809 = tensor.from_elements %7966, %c2_i64 : tensor<2xi64> + %7970 = stablehlo.real_dynamic_slice %7965, %c_24, %from_elements_2809, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2810 = tensor.dim %7970, %c0 : tensor + %7971 = arith.index_cast %dim_2810 : index to i64 + %from_elements_2811 = tensor.from_elements %7971 : tensor<1xi64> + %7972 = stablehlo.dynamic_reshape %7970, %from_elements_2811 : (tensor, tensor<1xi64>) -> tensor + %dim_2812 = tensor.dim %7972, %c0 : tensor + %7973 = 
arith.index_cast %dim_2812 : index to i64 + %from_elements_2813 = tensor.from_elements %7973, %c1_i64 : tensor<2xi64> + %7974 = stablehlo.dynamic_reshape %7972, %from_elements_2813 : (tensor, tensor<2xi64>) -> tensor + %dim_2814 = tensor.dim %7974, %c0 : tensor + %7975 = arith.index_cast %dim_2814 : index to i64 + %from_elements_2815 = tensor.from_elements %c1_i64, %7975, %c4096_i64 : tensor<3xi64> + %7976 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_2815, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2816 = tensor.dim %7976, %c1 : tensor<1x?x4096xi64> + %7977 = arith.index_cast %dim_2816 : index to i64 + %from_elements_2817 = tensor.from_elements %c1_i64, %7977, %c4096_i64, %c1_i64 : tensor<4xi64> + %7978 = stablehlo.dynamic_reshape %7976, %from_elements_2817 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %7979 = stablehlo.dynamic_broadcast_in_dim %7974, %from_elements_2815, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2818 = tensor.dim %7979, %c1 : tensor<1x?x4096xi64> + %7980 = arith.index_cast %dim_2818 : index to i64 + %from_elements_2819 = tensor.from_elements %c1_i64, %7980, %c4096_i64, %c1_i64 : tensor<4xi64> + %7981 = stablehlo.dynamic_reshape %7979, %from_elements_2819 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %7982 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_2815, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2820 = tensor.dim %7982, %c1 : tensor<1x?x4096xi64> + %7983 = arith.index_cast %dim_2820 : index to i64 + %from_elements_2821 = tensor.from_elements %c1_i64, %7983, %c4096_i64, %c1_i64 : tensor<4xi64> + %7984 = stablehlo.dynamic_reshape %7982, %from_elements_2821 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %7985 = stablehlo.concatenate %7978, %7981, %7984, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %7986 = "stablehlo.gather"(%7530, %7985) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %7987 = shape.shape_of %7986 : tensor<1x?x4096xf32> -> tensor<3xindex> + %7988 = shape.num_elements %7987 : tensor<3xindex> -> index + %7989 = stablehlo.compute_reshape_shape %7988, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %7990 = stablehlo.dynamic_reshape %7986, %7989 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %7991 = stablehlo.dot %7990, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %7992 = stablehlo.logistic %7991 : tensor + %7993 = shape.shape_of %7992 : tensor -> tensor<2xindex> + %7994 = shape.shape_of %7991 : tensor -> tensor<2xindex> + %7995 = shape.cstr_broadcastable %7993, %7994 : tensor<2xindex>, tensor<2xindex> + %7996 = shape.assuming %7995 -> (tensor) { + %19688 = shape.broadcast %7993, %7994 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %7992, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %7991, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %7997 = shape.shape_of %7996 : tensor -> tensor<2xindex> + %7998 = shape.cstr_broadcastable %7997, %7994 : tensor<2xindex>, tensor<2xindex> + %7999 = shape.assuming %7998 -> (tensor) { + %19688 = shape.broadcast %7997, 
%7994 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %7996, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %7991, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %8000 = stablehlo.dot %7999, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_2822 = tensor.dim %7972, %c0 : tensor + %8001 = arith.index_cast %dim_2822 : index to i64 + %from_elements_2823 = tensor.from_elements %8001, %c1_i64 : tensor<2xi64> + %8002 = stablehlo.dynamic_reshape %7972, %from_elements_2823 : (tensor, tensor<2xi64>) -> tensor + %dim_2824 = tensor.dim %7969, %c0 : tensor + %8003 = arith.index_cast %dim_2824 : index to i64 + %from_elements_2825 = tensor.from_elements %8003, %c1_i64 : tensor<2xi64> + %8004 = stablehlo.dynamic_reshape %7969, %from_elements_2825 : (tensor, tensor<2xi64>) -> tensor + %8005 = stablehlo.concatenate %8002, %8004, dim = 1 : (tensor, tensor) -> tensor + %8006 = "stablehlo.gather"(%7559, %8005) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %8007 = shape.shape_of %8000 : tensor -> tensor<2xindex> + %8008 = shape.shape_of %8006 : tensor -> tensor<2xindex> + %8009 = shape.cstr_broadcastable %8007, %8008 : tensor<2xindex>, tensor<2xindex> + %8010 = shape.assuming %8009 -> (tensor) { + %19688 = shape.broadcast %8007, %8008 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %8000, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %8006, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %8011 = shape.shape_of %8010 : tensor -> tensor<2xindex> + %8012 = stablehlo.dynamic_broadcast_in_dim %8010, %8011, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %8013 = stablehlo.dynamic_broadcast_in_dim %213, %8011, dims = [] : (tensor, tensor<2xindex>) -> tensor + %8014 = stablehlo.multiply %8012, %8013 : tensor + %dim_2826 = tensor.dim %7974, %c0 : tensor + %8015 = arith.index_cast %dim_2826 : index to i64 + %dim_2827 = tensor.dim %8010, %c0 : tensor + %8016 = arith.index_cast %dim_2827 : index to i64 + %8017 = arith.maxsi %8015, %8016 : i64 + %8018 = arith.index_cast %8017 : i64 to index + %from_elements_2828 = tensor.from_elements %8018, %c4096 : tensor<2xindex> + %8019 = stablehlo.dynamic_broadcast_in_dim %7974, %from_elements_2828, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_2829 = tensor.dim %8019, %c0 : tensor + %8020 = arith.index_cast %dim_2829 : index to i64 + %from_elements_2830 = tensor.from_elements %8020, %c4096_i64 : tensor<2xi64> + %8021 = stablehlo.real_dynamic_slice %8014, %c_22, %from_elements_2830, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_2831 = tensor.from_elements %8020, %c4096_i64, %c1_i64 : tensor<3xi64> + %8022 = stablehlo.dynamic_reshape %8019, %from_elements_2831 : (tensor, tensor<3xi64>) -> tensor + %8023 = stablehlo.dynamic_iota %from_elements_2831, dim = 1 : (tensor<3xi64>) -> tensor + %8024 = stablehlo.concatenate %8022, %8023, dim = 2 : (tensor, tensor) -> tensor + %8025 = "stablehlo.scatter"(%7962, %8024, %8021) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, 
unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %8026 = stablehlo.reshape %8025 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %8027 = stablehlo.add %7492, %8026 : tensor<3x1x4096xf32> + %8028 = stablehlo.broadcast_in_dim %8027, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %8029 = stablehlo.power %8028, %15 : tensor<3x1x4096xf32> + %8030 = stablehlo.reduce(%8029 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %8031 = stablehlo.reshape %8030 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %8032 = stablehlo.broadcast_in_dim %8031, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %8033 = stablehlo.divide %8032, %21 : tensor<3x1x1xf32> + %8034 = stablehlo.broadcast_in_dim %8033, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %8035 = stablehlo.add %8034, %25 : tensor<3x1x1xf32> + %8036 = stablehlo.rsqrt %8035 : tensor<3x1x1xf32> + %8037 = stablehlo.broadcast_in_dim %8036, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %8038 = stablehlo.multiply %8028, %8037 : tensor<3x1x4096xf32> + %8039 = stablehlo.broadcast_in_dim %8038, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %8040 = stablehlo.multiply %8039, %31 : tensor<3x1x4096xf32> + %8041 = stablehlo.reshape %8040 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %8042 = stablehlo.dot %8041, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %8043 = stablehlo.reshape %8042 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %8044 = stablehlo.dot %8041, %37 : (tensor<3x4096xf32>, tensor<4096x1024xf32>) -> tensor<3x1024xf32> + %8045 = stablehlo.reshape %8044 : (tensor<3x1024xf32>) -> tensor<3x1x1024xf32> + %8046 = stablehlo.reshape %8043 : (tensor<3x1x4096xf32>) -> tensor<3x1x32x128xf32> + %8047 = stablehlo.transpose %8046, dims = [0, 2, 1, 3] : (tensor<3x1x32x128xf32>) -> tensor<3x32x1x128xf32> + %8048 = stablehlo.reshape %8045 : (tensor<3x1x1024xf32>) -> tensor<3x1x8x128xf32> + %8049 = stablehlo.transpose %8048, dims = [0, 2, 1, 3] : (tensor<3x1x8x128xf32>) -> tensor<3x8x1x128xf32> + %8050 = stablehlo.slice %arg26 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %8051 = stablehlo.slice %arg27 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %8052 = "stablehlo.gather"(%8050, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %8053 = stablehlo.reshape %8052 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %8054 = "stablehlo.gather"(%8051, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %8055 = stablehlo.reshape %8054 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %8056 = stablehlo.broadcast_in_dim %8047, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %8057 = stablehlo.broadcast_in_dim %8053, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %8058 = stablehlo.multiply %8056, %8057 : tensor<3x32x1x128xf32> + %8059 = stablehlo.slice %8047 [0:3, 0:32, 0:1, 0:64] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %8060 = stablehlo.slice %8047 [0:3, 0:32, 0:1, 64:128] : (tensor<3x32x1x128xf32>) -> 
tensor<3x32x1x64xf32> + %8061 = stablehlo.negate %8060 : tensor<3x32x1x64xf32> + %8062 = stablehlo.concatenate %8061, %8059, dim = 3 : (tensor<3x32x1x64xf32>, tensor<3x32x1x64xf32>) -> tensor<3x32x1x128xf32> + %8063 = stablehlo.broadcast_in_dim %8062, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %8064 = stablehlo.broadcast_in_dim %8055, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %8065 = stablehlo.multiply %8063, %8064 : tensor<3x32x1x128xf32> + %8066 = stablehlo.add %8058, %8065 : tensor<3x32x1x128xf32> + %8067 = stablehlo.broadcast_in_dim %8049, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %8068 = stablehlo.broadcast_in_dim %8053, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %8069 = stablehlo.multiply %8067, %8068 : tensor<3x8x1x128xf32> + %8070 = stablehlo.slice %8049 [0:3, 0:8, 0:1, 0:64] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %8071 = stablehlo.slice %8049 [0:3, 0:8, 0:1, 64:128] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %8072 = stablehlo.negate %8071 : tensor<3x8x1x64xf32> + %8073 = stablehlo.concatenate %8072, %8070, dim = 3 : (tensor<3x8x1x64xf32>, tensor<3x8x1x64xf32>) -> tensor<3x8x1x128xf32> + %8074 = stablehlo.broadcast_in_dim %8073, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %8075 = stablehlo.broadcast_in_dim %8055, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %8076 = stablehlo.multiply %8074, %8075 : tensor<3x8x1x128xf32> + %8077 = stablehlo.add %8069, %8076 : tensor<3x8x1x128xf32> + %8078 = stablehlo.concatenate %arg91, %8077, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %8079 = stablehlo.concatenate %arg92, %8049, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %8080 = stablehlo.reshape %8078 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %8081 = stablehlo.broadcast_in_dim %8080, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %8082 = stablehlo.reshape %8081 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %8083 = stablehlo.reshape %8079 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %8084 = stablehlo.broadcast_in_dim %8083, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %8085 = stablehlo.reshape %8084 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %8086 = stablehlo.transpose %8082, dims = [0, 1, 3, 2] : (tensor<3x32x8x128xf32>) -> tensor<3x32x128x8xf32> + %8087 = stablehlo.reshape %8066 : (tensor<3x32x1x128xf32>) -> tensor<96x1x128xf32> + %8088 = stablehlo.reshape %8086 : (tensor<3x32x128x8xf32>) -> tensor<96x128x8xf32> + %8089 = stablehlo.broadcast_in_dim %8088, dims = [0, 1, 2] : (tensor<96x128x8xf32>) -> tensor<96x128x8xf32> + %8090 = stablehlo.dot_general %8087, %8089, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x128xf32>, tensor<96x128x8xf32>) -> tensor<96x1x8xf32> + %8091 = stablehlo.reshape %8090 : (tensor<96x1x8xf32>) -> tensor<3x32x1x8xf32> + %8092 = stablehlo.broadcast_in_dim %8091, dims = [0, 1, 2, 3] : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %8093 = stablehlo.divide %8092, %89 : tensor<3x32x1x8xf32> + %8094 = stablehlo.custom_call @byteir.softmax(%8093) {byteir_attrs = {axis = 3 : i64}} : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %8095 = stablehlo.reshape %8094 : (tensor<3x32x1x8xf32>) -> tensor<96x1x8xf32> + %8096 = stablehlo.reshape %8085 : 
(tensor<3x32x8x128xf32>) -> tensor<96x8x128xf32> + %8097 = stablehlo.broadcast_in_dim %8096, dims = [0, 1, 2] : (tensor<96x8x128xf32>) -> tensor<96x8x128xf32> + %8098 = stablehlo.dot_general %8095, %8097, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x8xf32>, tensor<96x8x128xf32>) -> tensor<96x1x128xf32> + %8099 = stablehlo.reshape %8098 : (tensor<96x1x128xf32>) -> tensor<3x32x1x128xf32> + %8100 = stablehlo.transpose %8099, dims = [0, 2, 1, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x1x32x128xf32> + %8101 = stablehlo.reshape %8100 : (tensor<3x1x32x128xf32>) -> tensor<3x1x4096xf32> + %8102 = stablehlo.reshape %8101 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %8103 = stablehlo.dot %8102, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %8104 = stablehlo.reshape %8103 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %8105 = stablehlo.add %8027, %8104 : tensor<3x1x4096xf32> + %8106 = stablehlo.broadcast_in_dim %8105, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %8107 = stablehlo.power %8106, %15 : tensor<3x1x4096xf32> + %8108 = stablehlo.reduce(%8107 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %8109 = stablehlo.reshape %8108 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %8110 = stablehlo.broadcast_in_dim %8109, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %8111 = stablehlo.divide %8110, %21 : tensor<3x1x1xf32> + %8112 = stablehlo.broadcast_in_dim %8111, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %8113 = stablehlo.add %8112, %25 : tensor<3x1x1xf32> + %8114 = stablehlo.rsqrt %8113 : tensor<3x1x1xf32> + %8115 = stablehlo.broadcast_in_dim %8114, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %8116 = stablehlo.multiply %8106, %8115 : tensor<3x1x4096xf32> + %8117 = stablehlo.broadcast_in_dim %8116, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %8118 = stablehlo.multiply %8117, %31 : tensor<3x1x4096xf32> + %8119 = stablehlo.reshape %8118 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %8120 = stablehlo.dot %8119, %117 : (tensor<3x4096xf32>, tensor<4096x8xf32>) -> tensor<3x8xf32> + %8121 = stablehlo.custom_call @byteir.softmax(%8120) {byteir_attrs = {axis = 1 : i64}} : (tensor<3x8xf32>) -> tensor<3x8xf32> + %8122:2 = stablehlo.custom_call @byteir.top_k(%8121) {byteir_attrs = {axis = [1], k = 2 : i64, sorted = true}} : (tensor<3x8xf32>) -> (tensor<3x2xf32>, tensor<3x2xi64>) + %8123 = stablehlo.reduce(%8122#0 init: %cst_20) applies stablehlo.add across dimensions = [1] : (tensor<3x2xf32>, tensor) -> tensor<3xf32> + %8124 = stablehlo.reshape %8123 : (tensor<3xf32>) -> tensor<3x1xf32> + %8125 = stablehlo.broadcast_in_dim %8122#0, dims = [0, 1] : (tensor<3x2xf32>) -> tensor<3x2xf32> + %8126 = stablehlo.broadcast_in_dim %8124, dims = [0, 1] : (tensor<3x1xf32>) -> tensor<3x2xf32> + %8127 = stablehlo.divide %8125, %8126 : tensor<3x2xf32> + %8128 = stablehlo.reshape %8122#1 : (tensor<3x2xi64>) -> tensor<3x2x1xi64> + %8129 = stablehlo.broadcast_in_dim %8128, dims = [0, 1, 2] : (tensor<3x2x1xi64>) -> tensor<3x2x8xi64> + %8130 = stablehlo.compare EQ, %8129, %137, SIGNED : (tensor<3x2x8xi64>, tensor<3x2x8xi64>) -> tensor<3x2x8xi1> + %8131 = stablehlo.convert %8130 : (tensor<3x2x8xi1>) -> tensor<3x2x8xi64> + %8132 = stablehlo.transpose %8131, dims = [2, 1, 0] : (tensor<3x2x8xi64>) -> tensor<8x2x3xi64> + %8133 = stablehlo.slice %8132 [0:1, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %8134 
= stablehlo.reshape %8133 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %8135 = stablehlo.custom_call @byteir.non_zero(%8134) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_2832 = tensor.dim %8135, %c0 : tensor + %8136 = arith.index_cast %dim_2832 : index to i64 + %from_elements_2833 = tensor.from_elements %8136, %c1_i64 : tensor<2xi64> + %8137 = stablehlo.real_dynamic_slice %8135, %c_22, %from_elements_2833, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2834 = tensor.dim %8137, %c0 : tensor + %8138 = arith.index_cast %dim_2834 : index to i64 + %from_elements_2835 = tensor.from_elements %8138 : tensor<1xi64> + %8139 = stablehlo.dynamic_reshape %8137, %from_elements_2835 : (tensor, tensor<1xi64>) -> tensor + %from_elements_2836 = tensor.from_elements %8136, %c2_i64 : tensor<2xi64> + %8140 = stablehlo.real_dynamic_slice %8135, %c_24, %from_elements_2836, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2837 = tensor.dim %8140, %c0 : tensor + %8141 = arith.index_cast %dim_2837 : index to i64 + %from_elements_2838 = tensor.from_elements %8141 : tensor<1xi64> + %8142 = stablehlo.dynamic_reshape %8140, %from_elements_2838 : (tensor, tensor<1xi64>) -> tensor + %8143 = stablehlo.reshape %8119 : (tensor<3x4096xf32>) -> tensor<1x3x4096xf32> + %dim_2839 = tensor.dim %8142, %c0 : tensor + %8144 = arith.index_cast %dim_2839 : index to i64 + %from_elements_2840 = tensor.from_elements %8144, %c1_i64 : tensor<2xi64> + %8145 = stablehlo.dynamic_reshape %8142, %from_elements_2840 : (tensor, tensor<2xi64>) -> tensor + %dim_2841 = tensor.dim %8145, %c0 : tensor + %8146 = arith.index_cast %dim_2841 : index to i64 + %from_elements_2842 = tensor.from_elements %c1_i64, %8146, %c4096_i64 : tensor<3xi64> + %8147 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_2842, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2843 = tensor.dim %8147, %c1 : tensor<1x?x4096xi64> + %8148 = arith.index_cast %dim_2843 : index to i64 + %from_elements_2844 = tensor.from_elements %c1_i64, %8148, %c4096_i64, %c1_i64 : tensor<4xi64> + %8149 = stablehlo.dynamic_reshape %8147, %from_elements_2844 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %8150 = stablehlo.dynamic_broadcast_in_dim %8145, %from_elements_2842, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2845 = tensor.dim %8150, %c1 : tensor<1x?x4096xi64> + %8151 = arith.index_cast %dim_2845 : index to i64 + %from_elements_2846 = tensor.from_elements %c1_i64, %8151, %c4096_i64, %c1_i64 : tensor<4xi64> + %8152 = stablehlo.dynamic_reshape %8150, %from_elements_2846 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %8153 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_2842, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2847 = tensor.dim %8153, %c1 : tensor<1x?x4096xi64> + %8154 = arith.index_cast %dim_2847 : index to i64 + %from_elements_2848 = tensor.from_elements %c1_i64, %8154, %c4096_i64, %c1_i64 : tensor<4xi64> + %8155 = stablehlo.dynamic_reshape %8153, %from_elements_2848 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %8156 = stablehlo.concatenate %8149, %8152, %8155, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %8157 = "stablehlo.gather"(%8143, %8156) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, 
tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %8158 = shape.shape_of %8157 : tensor<1x?x4096xf32> -> tensor<3xindex> + %8159 = shape.num_elements %8158 : tensor<3xindex> -> index + %8160 = stablehlo.compute_reshape_shape %8159, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %8161 = stablehlo.dynamic_reshape %8157, %8160 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %8162 = stablehlo.dot %8161, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %8163 = stablehlo.logistic %8162 : tensor + %8164 = shape.shape_of %8163 : tensor -> tensor<2xindex> + %8165 = shape.shape_of %8162 : tensor -> tensor<2xindex> + %8166 = shape.cstr_broadcastable %8164, %8165 : tensor<2xindex>, tensor<2xindex> + %8167 = shape.assuming %8166 -> (tensor) { + %19688 = shape.broadcast %8164, %8165 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %8163, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %8162, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %8168 = shape.shape_of %8167 : tensor -> tensor<2xindex> + %8169 = shape.cstr_broadcastable %8168, %8165 : tensor<2xindex>, tensor<2xindex> + %8170 = shape.assuming %8169 -> (tensor) { + %19688 = shape.broadcast %8168, %8165 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %8167, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %8162, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %8171 = stablehlo.dot %8170, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %8172 = stablehlo.reshape %8127 : (tensor<3x2xf32>) -> tensor<3x2x1xf32> + %dim_2849 = tensor.dim %8142, %c0 : tensor + %8173 = arith.index_cast %dim_2849 : index to i64 + %from_elements_2850 = tensor.from_elements %8173, %c1_i64 : tensor<2xi64> + %8174 = stablehlo.dynamic_reshape %8142, %from_elements_2850 : (tensor, tensor<2xi64>) -> tensor + %dim_2851 = tensor.dim %8139, %c0 : tensor + %8175 = arith.index_cast %dim_2851 : index to i64 + %from_elements_2852 = tensor.from_elements %8175, %c1_i64 : tensor<2xi64> + %8176 = stablehlo.dynamic_reshape %8139, %from_elements_2852 : (tensor, tensor<2xi64>) -> tensor + %8177 = stablehlo.concatenate %8174, %8176, dim = 1 : (tensor, tensor) -> tensor + %8178 = "stablehlo.gather"(%8172, %8177) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %8179 = shape.shape_of %8171 : tensor -> tensor<2xindex> + %8180 = shape.shape_of %8178 : tensor -> tensor<2xindex> + %8181 = shape.cstr_broadcastable %8179, %8180 : tensor<2xindex>, tensor<2xindex> + %8182 = shape.assuming %8181 -> (tensor) { + %19688 = shape.broadcast %8179, %8180 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %8171, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %8178, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %8183 = shape.shape_of %8182 : tensor -> tensor<2xindex> + %8184 = stablehlo.dynamic_broadcast_in_dim %8182, %8183, dims = [0, 1] : (tensor, tensor<2xindex>) -> 
tensor + %8185 = stablehlo.dynamic_broadcast_in_dim %213, %8183, dims = [] : (tensor, tensor<2xindex>) -> tensor + %8186 = stablehlo.multiply %8184, %8185 : tensor + %dim_2853 = tensor.dim %8145, %c0 : tensor + %8187 = arith.index_cast %dim_2853 : index to i64 + %dim_2854 = tensor.dim %8182, %c0 : tensor + %8188 = arith.index_cast %dim_2854 : index to i64 + %8189 = arith.maxsi %8187, %8188 : i64 + %8190 = arith.index_cast %8189 : i64 to index + %from_elements_2855 = tensor.from_elements %8190, %c4096 : tensor<2xindex> + %8191 = stablehlo.dynamic_broadcast_in_dim %8145, %from_elements_2855, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_2856 = tensor.dim %8191, %c0 : tensor + %8192 = arith.index_cast %dim_2856 : index to i64 + %from_elements_2857 = tensor.from_elements %8192, %c4096_i64 : tensor<2xi64> + %8193 = stablehlo.real_dynamic_slice %8186, %c_22, %from_elements_2857, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_2858 = tensor.from_elements %8192, %c4096_i64, %c1_i64 : tensor<3xi64> + %8194 = stablehlo.dynamic_reshape %8191, %from_elements_2858 : (tensor, tensor<3xi64>) -> tensor + %8195 = stablehlo.dynamic_iota %from_elements_2858, dim = 1 : (tensor<3xi64>) -> tensor + %8196 = stablehlo.concatenate %8194, %8195, dim = 2 : (tensor, tensor) -> tensor + %8197 = "stablehlo.scatter"(%cst_2, %8196, %8193) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %8198 = stablehlo.slice %8132 [1:2, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %8199 = stablehlo.reshape %8198 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %8200 = stablehlo.custom_call @byteir.non_zero(%8199) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_2859 = tensor.dim %8200, %c0 : tensor + %8201 = arith.index_cast %dim_2859 : index to i64 + %from_elements_2860 = tensor.from_elements %8201, %c1_i64 : tensor<2xi64> + %8202 = stablehlo.real_dynamic_slice %8200, %c_22, %from_elements_2860, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2861 = tensor.dim %8202, %c0 : tensor + %8203 = arith.index_cast %dim_2861 : index to i64 + %from_elements_2862 = tensor.from_elements %8203 : tensor<1xi64> + %8204 = stablehlo.dynamic_reshape %8202, %from_elements_2862 : (tensor, tensor<1xi64>) -> tensor + %from_elements_2863 = tensor.from_elements %8201, %c2_i64 : tensor<2xi64> + %8205 = stablehlo.real_dynamic_slice %8200, %c_24, %from_elements_2863, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2864 = tensor.dim %8205, %c0 : tensor + %8206 = arith.index_cast %dim_2864 : index to i64 + %from_elements_2865 = tensor.from_elements %8206 : tensor<1xi64> + %8207 = stablehlo.dynamic_reshape %8205, %from_elements_2865 : (tensor, tensor<1xi64>) -> tensor + %dim_2866 = tensor.dim %8207, %c0 : tensor + %8208 = arith.index_cast %dim_2866 : index to i64 + %from_elements_2867 = tensor.from_elements %8208, %c1_i64 : tensor<2xi64> + %8209 = stablehlo.dynamic_reshape %8207, %from_elements_2867 : (tensor, tensor<2xi64>) -> tensor + %dim_2868 = tensor.dim %8209, %c0 : tensor + %8210 = arith.index_cast %dim_2868 : index to i64 + %from_elements_2869 = tensor.from_elements %c1_i64, %8210, %c4096_i64 : tensor<3xi64> + %8211 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_2869, 
dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2870 = tensor.dim %8211, %c1 : tensor<1x?x4096xi64> + %8212 = arith.index_cast %dim_2870 : index to i64 + %from_elements_2871 = tensor.from_elements %c1_i64, %8212, %c4096_i64, %c1_i64 : tensor<4xi64> + %8213 = stablehlo.dynamic_reshape %8211, %from_elements_2871 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %8214 = stablehlo.dynamic_broadcast_in_dim %8209, %from_elements_2869, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2872 = tensor.dim %8214, %c1 : tensor<1x?x4096xi64> + %8215 = arith.index_cast %dim_2872 : index to i64 + %from_elements_2873 = tensor.from_elements %c1_i64, %8215, %c4096_i64, %c1_i64 : tensor<4xi64> + %8216 = stablehlo.dynamic_reshape %8214, %from_elements_2873 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %8217 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_2869, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2874 = tensor.dim %8217, %c1 : tensor<1x?x4096xi64> + %8218 = arith.index_cast %dim_2874 : index to i64 + %from_elements_2875 = tensor.from_elements %c1_i64, %8218, %c4096_i64, %c1_i64 : tensor<4xi64> + %8219 = stablehlo.dynamic_reshape %8217, %from_elements_2875 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %8220 = stablehlo.concatenate %8213, %8216, %8219, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %8221 = "stablehlo.gather"(%8143, %8220) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %8222 = shape.shape_of %8221 : tensor<1x?x4096xf32> -> tensor<3xindex> + %8223 = shape.num_elements %8222 : tensor<3xindex> -> index + %8224 = stablehlo.compute_reshape_shape %8223, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %8225 = stablehlo.dynamic_reshape %8221, %8224 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %8226 = stablehlo.dot %8225, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %8227 = stablehlo.logistic %8226 : tensor + %8228 = shape.shape_of %8227 : tensor -> tensor<2xindex> + %8229 = shape.shape_of %8226 : tensor -> tensor<2xindex> + %8230 = shape.cstr_broadcastable %8228, %8229 : tensor<2xindex>, tensor<2xindex> + %8231 = shape.assuming %8230 -> (tensor) { + %19688 = shape.broadcast %8228, %8229 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %8227, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %8226, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %8232 = shape.shape_of %8231 : tensor -> tensor<2xindex> + %8233 = shape.cstr_broadcastable %8232, %8229 : tensor<2xindex>, tensor<2xindex> + %8234 = shape.assuming %8233 -> (tensor) { + %19688 = shape.broadcast %8232, %8229 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %8231, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %8226, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %8235 = stablehlo.dot %8234, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_2876 
= tensor.dim %8207, %c0 : tensor + %8236 = arith.index_cast %dim_2876 : index to i64 + %from_elements_2877 = tensor.from_elements %8236, %c1_i64 : tensor<2xi64> + %8237 = stablehlo.dynamic_reshape %8207, %from_elements_2877 : (tensor, tensor<2xi64>) -> tensor + %dim_2878 = tensor.dim %8204, %c0 : tensor + %8238 = arith.index_cast %dim_2878 : index to i64 + %from_elements_2879 = tensor.from_elements %8238, %c1_i64 : tensor<2xi64> + %8239 = stablehlo.dynamic_reshape %8204, %from_elements_2879 : (tensor, tensor<2xi64>) -> tensor + %8240 = stablehlo.concatenate %8237, %8239, dim = 1 : (tensor, tensor) -> tensor + %8241 = "stablehlo.gather"(%8172, %8240) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %8242 = shape.shape_of %8235 : tensor -> tensor<2xindex> + %8243 = shape.shape_of %8241 : tensor -> tensor<2xindex> + %8244 = shape.cstr_broadcastable %8242, %8243 : tensor<2xindex>, tensor<2xindex> + %8245 = shape.assuming %8244 -> (tensor) { + %19688 = shape.broadcast %8242, %8243 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %8235, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %8241, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %8246 = shape.shape_of %8245 : tensor -> tensor<2xindex> + %8247 = stablehlo.dynamic_broadcast_in_dim %8245, %8246, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %8248 = stablehlo.dynamic_broadcast_in_dim %213, %8246, dims = [] : (tensor, tensor<2xindex>) -> tensor + %8249 = stablehlo.multiply %8247, %8248 : tensor + %dim_2880 = tensor.dim %8209, %c0 : tensor + %8250 = arith.index_cast %dim_2880 : index to i64 + %dim_2881 = tensor.dim %8245, %c0 : tensor + %8251 = arith.index_cast %dim_2881 : index to i64 + %8252 = arith.maxsi %8250, %8251 : i64 + %8253 = arith.index_cast %8252 : i64 to index + %from_elements_2882 = tensor.from_elements %8253, %c4096 : tensor<2xindex> + %8254 = stablehlo.dynamic_broadcast_in_dim %8209, %from_elements_2882, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_2883 = tensor.dim %8254, %c0 : tensor + %8255 = arith.index_cast %dim_2883 : index to i64 + %from_elements_2884 = tensor.from_elements %8255, %c4096_i64 : tensor<2xi64> + %8256 = stablehlo.real_dynamic_slice %8249, %c_22, %from_elements_2884, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_2885 = tensor.from_elements %8255, %c4096_i64, %c1_i64 : tensor<3xi64> + %8257 = stablehlo.dynamic_reshape %8254, %from_elements_2885 : (tensor, tensor<3xi64>) -> tensor + %8258 = stablehlo.dynamic_iota %from_elements_2885, dim = 1 : (tensor<3xi64>) -> tensor + %8259 = stablehlo.concatenate %8257, %8258, dim = 2 : (tensor, tensor) -> tensor + %8260 = "stablehlo.scatter"(%8197, %8259, %8256) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %8261 = stablehlo.slice %8132 [2:3, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %8262 = stablehlo.reshape %8261 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %8263 = stablehlo.custom_call @byteir.non_zero(%8262) {byteir_attrs = {}} : 
(tensor<2x3xi64>) -> tensor + %dim_2886 = tensor.dim %8263, %c0 : tensor + %8264 = arith.index_cast %dim_2886 : index to i64 + %from_elements_2887 = tensor.from_elements %8264, %c1_i64 : tensor<2xi64> + %8265 = stablehlo.real_dynamic_slice %8263, %c_22, %from_elements_2887, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2888 = tensor.dim %8265, %c0 : tensor + %8266 = arith.index_cast %dim_2888 : index to i64 + %from_elements_2889 = tensor.from_elements %8266 : tensor<1xi64> + %8267 = stablehlo.dynamic_reshape %8265, %from_elements_2889 : (tensor, tensor<1xi64>) -> tensor + %from_elements_2890 = tensor.from_elements %8264, %c2_i64 : tensor<2xi64> + %8268 = stablehlo.real_dynamic_slice %8263, %c_24, %from_elements_2890, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2891 = tensor.dim %8268, %c0 : tensor + %8269 = arith.index_cast %dim_2891 : index to i64 + %from_elements_2892 = tensor.from_elements %8269 : tensor<1xi64> + %8270 = stablehlo.dynamic_reshape %8268, %from_elements_2892 : (tensor, tensor<1xi64>) -> tensor + %dim_2893 = tensor.dim %8270, %c0 : tensor + %8271 = arith.index_cast %dim_2893 : index to i64 + %from_elements_2894 = tensor.from_elements %8271, %c1_i64 : tensor<2xi64> + %8272 = stablehlo.dynamic_reshape %8270, %from_elements_2894 : (tensor, tensor<2xi64>) -> tensor + %dim_2895 = tensor.dim %8272, %c0 : tensor + %8273 = arith.index_cast %dim_2895 : index to i64 + %from_elements_2896 = tensor.from_elements %c1_i64, %8273, %c4096_i64 : tensor<3xi64> + %8274 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_2896, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2897 = tensor.dim %8274, %c1 : tensor<1x?x4096xi64> + %8275 = arith.index_cast %dim_2897 : index to i64 + %from_elements_2898 = tensor.from_elements %c1_i64, %8275, %c4096_i64, %c1_i64 : tensor<4xi64> + %8276 = stablehlo.dynamic_reshape %8274, %from_elements_2898 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %8277 = stablehlo.dynamic_broadcast_in_dim %8272, %from_elements_2896, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2899 = tensor.dim %8277, %c1 : tensor<1x?x4096xi64> + %8278 = arith.index_cast %dim_2899 : index to i64 + %from_elements_2900 = tensor.from_elements %c1_i64, %8278, %c4096_i64, %c1_i64 : tensor<4xi64> + %8279 = stablehlo.dynamic_reshape %8277, %from_elements_2900 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %8280 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_2896, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2901 = tensor.dim %8280, %c1 : tensor<1x?x4096xi64> + %8281 = arith.index_cast %dim_2901 : index to i64 + %from_elements_2902 = tensor.from_elements %c1_i64, %8281, %c4096_i64, %c1_i64 : tensor<4xi64> + %8282 = stablehlo.dynamic_reshape %8280, %from_elements_2902 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %8283 = stablehlo.concatenate %8276, %8279, %8282, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %8284 = "stablehlo.gather"(%8143, %8283) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %8285 = shape.shape_of %8284 : tensor<1x?x4096xf32> -> tensor<3xindex> + %8286 = shape.num_elements %8285 : tensor<3xindex> -> index + %8287 = stablehlo.compute_reshape_shape %8286, %c_23 : 
(index, tensor<2xi64>) -> tensor<2xi64> + %8288 = stablehlo.dynamic_reshape %8284, %8287 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %8289 = stablehlo.dot %8288, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %8290 = stablehlo.logistic %8289 : tensor + %8291 = shape.shape_of %8290 : tensor -> tensor<2xindex> + %8292 = shape.shape_of %8289 : tensor -> tensor<2xindex> + %8293 = shape.cstr_broadcastable %8291, %8292 : tensor<2xindex>, tensor<2xindex> + %8294 = shape.assuming %8293 -> (tensor) { + %19688 = shape.broadcast %8291, %8292 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %8290, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %8289, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %8295 = shape.shape_of %8294 : tensor -> tensor<2xindex> + %8296 = shape.cstr_broadcastable %8295, %8292 : tensor<2xindex>, tensor<2xindex> + %8297 = shape.assuming %8296 -> (tensor) { + %19688 = shape.broadcast %8295, %8292 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %8294, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %8289, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %8298 = stablehlo.dot %8297, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_2903 = tensor.dim %8270, %c0 : tensor + %8299 = arith.index_cast %dim_2903 : index to i64 + %from_elements_2904 = tensor.from_elements %8299, %c1_i64 : tensor<2xi64> + %8300 = stablehlo.dynamic_reshape %8270, %from_elements_2904 : (tensor, tensor<2xi64>) -> tensor + %dim_2905 = tensor.dim %8267, %c0 : tensor + %8301 = arith.index_cast %dim_2905 : index to i64 + %from_elements_2906 = tensor.from_elements %8301, %c1_i64 : tensor<2xi64> + %8302 = stablehlo.dynamic_reshape %8267, %from_elements_2906 : (tensor, tensor<2xi64>) -> tensor + %8303 = stablehlo.concatenate %8300, %8302, dim = 1 : (tensor, tensor) -> tensor + %8304 = "stablehlo.gather"(%8172, %8303) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %8305 = shape.shape_of %8298 : tensor -> tensor<2xindex> + %8306 = shape.shape_of %8304 : tensor -> tensor<2xindex> + %8307 = shape.cstr_broadcastable %8305, %8306 : tensor<2xindex>, tensor<2xindex> + %8308 = shape.assuming %8307 -> (tensor) { + %19688 = shape.broadcast %8305, %8306 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %8298, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %8304, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %8309 = shape.shape_of %8308 : tensor -> tensor<2xindex> + %8310 = stablehlo.dynamic_broadcast_in_dim %8308, %8309, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %8311 = stablehlo.dynamic_broadcast_in_dim %213, %8309, dims = [] : (tensor, tensor<2xindex>) -> tensor + %8312 = stablehlo.multiply %8310, %8311 : tensor + %dim_2907 = tensor.dim %8272, %c0 : tensor + %8313 = arith.index_cast %dim_2907 : index to i64 + %dim_2908 = tensor.dim %8308, %c0 : tensor + %8314 = 
arith.index_cast %dim_2908 : index to i64 + %8315 = arith.maxsi %8313, %8314 : i64 + %8316 = arith.index_cast %8315 : i64 to index + %from_elements_2909 = tensor.from_elements %8316, %c4096 : tensor<2xindex> + %8317 = stablehlo.dynamic_broadcast_in_dim %8272, %from_elements_2909, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_2910 = tensor.dim %8317, %c0 : tensor + %8318 = arith.index_cast %dim_2910 : index to i64 + %from_elements_2911 = tensor.from_elements %8318, %c4096_i64 : tensor<2xi64> + %8319 = stablehlo.real_dynamic_slice %8312, %c_22, %from_elements_2911, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_2912 = tensor.from_elements %8318, %c4096_i64, %c1_i64 : tensor<3xi64> + %8320 = stablehlo.dynamic_reshape %8317, %from_elements_2912 : (tensor, tensor<3xi64>) -> tensor + %8321 = stablehlo.dynamic_iota %from_elements_2912, dim = 1 : (tensor<3xi64>) -> tensor + %8322 = stablehlo.concatenate %8320, %8321, dim = 2 : (tensor, tensor) -> tensor + %8323 = "stablehlo.scatter"(%8260, %8322, %8319) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %8324 = stablehlo.slice %8132 [3:4, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %8325 = stablehlo.reshape %8324 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %8326 = stablehlo.custom_call @byteir.non_zero(%8325) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_2913 = tensor.dim %8326, %c0 : tensor + %8327 = arith.index_cast %dim_2913 : index to i64 + %from_elements_2914 = tensor.from_elements %8327, %c1_i64 : tensor<2xi64> + %8328 = stablehlo.real_dynamic_slice %8326, %c_22, %from_elements_2914, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2915 = tensor.dim %8328, %c0 : tensor + %8329 = arith.index_cast %dim_2915 : index to i64 + %from_elements_2916 = tensor.from_elements %8329 : tensor<1xi64> + %8330 = stablehlo.dynamic_reshape %8328, %from_elements_2916 : (tensor, tensor<1xi64>) -> tensor + %from_elements_2917 = tensor.from_elements %8327, %c2_i64 : tensor<2xi64> + %8331 = stablehlo.real_dynamic_slice %8326, %c_24, %from_elements_2917, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2918 = tensor.dim %8331, %c0 : tensor + %8332 = arith.index_cast %dim_2918 : index to i64 + %from_elements_2919 = tensor.from_elements %8332 : tensor<1xi64> + %8333 = stablehlo.dynamic_reshape %8331, %from_elements_2919 : (tensor, tensor<1xi64>) -> tensor + %dim_2920 = tensor.dim %8333, %c0 : tensor + %8334 = arith.index_cast %dim_2920 : index to i64 + %from_elements_2921 = tensor.from_elements %8334, %c1_i64 : tensor<2xi64> + %8335 = stablehlo.dynamic_reshape %8333, %from_elements_2921 : (tensor, tensor<2xi64>) -> tensor + %dim_2922 = tensor.dim %8335, %c0 : tensor + %8336 = arith.index_cast %dim_2922 : index to i64 + %from_elements_2923 = tensor.from_elements %c1_i64, %8336, %c4096_i64 : tensor<3xi64> + %8337 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_2923, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2924 = tensor.dim %8337, %c1 : tensor<1x?x4096xi64> + %8338 = arith.index_cast %dim_2924 : index to i64 + %from_elements_2925 = tensor.from_elements %c1_i64, %8338, %c4096_i64, %c1_i64 : tensor<4xi64> + %8339 = 
stablehlo.dynamic_reshape %8337, %from_elements_2925 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %8340 = stablehlo.dynamic_broadcast_in_dim %8335, %from_elements_2923, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2926 = tensor.dim %8340, %c1 : tensor<1x?x4096xi64> + %8341 = arith.index_cast %dim_2926 : index to i64 + %from_elements_2927 = tensor.from_elements %c1_i64, %8341, %c4096_i64, %c1_i64 : tensor<4xi64> + %8342 = stablehlo.dynamic_reshape %8340, %from_elements_2927 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %8343 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_2923, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2928 = tensor.dim %8343, %c1 : tensor<1x?x4096xi64> + %8344 = arith.index_cast %dim_2928 : index to i64 + %from_elements_2929 = tensor.from_elements %c1_i64, %8344, %c4096_i64, %c1_i64 : tensor<4xi64> + %8345 = stablehlo.dynamic_reshape %8343, %from_elements_2929 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %8346 = stablehlo.concatenate %8339, %8342, %8345, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %8347 = "stablehlo.gather"(%8143, %8346) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %8348 = shape.shape_of %8347 : tensor<1x?x4096xf32> -> tensor<3xindex> + %8349 = shape.num_elements %8348 : tensor<3xindex> -> index + %8350 = stablehlo.compute_reshape_shape %8349, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %8351 = stablehlo.dynamic_reshape %8347, %8350 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %8352 = stablehlo.dot %8351, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %8353 = stablehlo.logistic %8352 : tensor + %8354 = shape.shape_of %8353 : tensor -> tensor<2xindex> + %8355 = shape.shape_of %8352 : tensor -> tensor<2xindex> + %8356 = shape.cstr_broadcastable %8354, %8355 : tensor<2xindex>, tensor<2xindex> + %8357 = shape.assuming %8356 -> (tensor) { + %19688 = shape.broadcast %8354, %8355 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %8353, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %8352, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %8358 = shape.shape_of %8357 : tensor -> tensor<2xindex> + %8359 = shape.cstr_broadcastable %8358, %8355 : tensor<2xindex>, tensor<2xindex> + %8360 = shape.assuming %8359 -> (tensor) { + %19688 = shape.broadcast %8358, %8355 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %8357, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %8352, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %8361 = stablehlo.dot %8360, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_2930 = tensor.dim %8333, %c0 : tensor + %8362 = arith.index_cast %dim_2930 : index to i64 + %from_elements_2931 = tensor.from_elements %8362, %c1_i64 : tensor<2xi64> + %8363 = stablehlo.dynamic_reshape %8333, %from_elements_2931 : (tensor, tensor<2xi64>) -> tensor + %dim_2932 = tensor.dim %8330, %c0 
: tensor + %8364 = arith.index_cast %dim_2932 : index to i64 + %from_elements_2933 = tensor.from_elements %8364, %c1_i64 : tensor<2xi64> + %8365 = stablehlo.dynamic_reshape %8330, %from_elements_2933 : (tensor, tensor<2xi64>) -> tensor + %8366 = stablehlo.concatenate %8363, %8365, dim = 1 : (tensor, tensor) -> tensor + %8367 = "stablehlo.gather"(%8172, %8366) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %8368 = shape.shape_of %8361 : tensor -> tensor<2xindex> + %8369 = shape.shape_of %8367 : tensor -> tensor<2xindex> + %8370 = shape.cstr_broadcastable %8368, %8369 : tensor<2xindex>, tensor<2xindex> + %8371 = shape.assuming %8370 -> (tensor) { + %19688 = shape.broadcast %8368, %8369 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %8361, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %8367, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %8372 = shape.shape_of %8371 : tensor -> tensor<2xindex> + %8373 = stablehlo.dynamic_broadcast_in_dim %8371, %8372, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %8374 = stablehlo.dynamic_broadcast_in_dim %213, %8372, dims = [] : (tensor, tensor<2xindex>) -> tensor + %8375 = stablehlo.multiply %8373, %8374 : tensor + %dim_2934 = tensor.dim %8335, %c0 : tensor + %8376 = arith.index_cast %dim_2934 : index to i64 + %dim_2935 = tensor.dim %8371, %c0 : tensor + %8377 = arith.index_cast %dim_2935 : index to i64 + %8378 = arith.maxsi %8376, %8377 : i64 + %8379 = arith.index_cast %8378 : i64 to index + %from_elements_2936 = tensor.from_elements %8379, %c4096 : tensor<2xindex> + %8380 = stablehlo.dynamic_broadcast_in_dim %8335, %from_elements_2936, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_2937 = tensor.dim %8380, %c0 : tensor + %8381 = arith.index_cast %dim_2937 : index to i64 + %from_elements_2938 = tensor.from_elements %8381, %c4096_i64 : tensor<2xi64> + %8382 = stablehlo.real_dynamic_slice %8375, %c_22, %from_elements_2938, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_2939 = tensor.from_elements %8381, %c4096_i64, %c1_i64 : tensor<3xi64> + %8383 = stablehlo.dynamic_reshape %8380, %from_elements_2939 : (tensor, tensor<3xi64>) -> tensor + %8384 = stablehlo.dynamic_iota %from_elements_2939, dim = 1 : (tensor<3xi64>) -> tensor + %8385 = stablehlo.concatenate %8383, %8384, dim = 2 : (tensor, tensor) -> tensor + %8386 = "stablehlo.scatter"(%8323, %8385, %8382) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %8387 = stablehlo.slice %8132 [4:5, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %8388 = stablehlo.reshape %8387 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %8389 = stablehlo.custom_call @byteir.non_zero(%8388) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_2940 = tensor.dim %8389, %c0 : tensor + %8390 = arith.index_cast %dim_2940 : index to i64 + %from_elements_2941 = tensor.from_elements %8390, %c1_i64 : tensor<2xi64> + %8391 = stablehlo.real_dynamic_slice %8389, %c_22, %from_elements_2941, %c_21 : (tensor, tensor<2xi64>, 
tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2942 = tensor.dim %8391, %c0 : tensor + %8392 = arith.index_cast %dim_2942 : index to i64 + %from_elements_2943 = tensor.from_elements %8392 : tensor<1xi64> + %8393 = stablehlo.dynamic_reshape %8391, %from_elements_2943 : (tensor, tensor<1xi64>) -> tensor + %from_elements_2944 = tensor.from_elements %8390, %c2_i64 : tensor<2xi64> + %8394 = stablehlo.real_dynamic_slice %8389, %c_24, %from_elements_2944, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2945 = tensor.dim %8394, %c0 : tensor + %8395 = arith.index_cast %dim_2945 : index to i64 + %from_elements_2946 = tensor.from_elements %8395 : tensor<1xi64> + %8396 = stablehlo.dynamic_reshape %8394, %from_elements_2946 : (tensor, tensor<1xi64>) -> tensor + %dim_2947 = tensor.dim %8396, %c0 : tensor + %8397 = arith.index_cast %dim_2947 : index to i64 + %from_elements_2948 = tensor.from_elements %8397, %c1_i64 : tensor<2xi64> + %8398 = stablehlo.dynamic_reshape %8396, %from_elements_2948 : (tensor, tensor<2xi64>) -> tensor + %dim_2949 = tensor.dim %8398, %c0 : tensor + %8399 = arith.index_cast %dim_2949 : index to i64 + %from_elements_2950 = tensor.from_elements %c1_i64, %8399, %c4096_i64 : tensor<3xi64> + %8400 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_2950, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2951 = tensor.dim %8400, %c1 : tensor<1x?x4096xi64> + %8401 = arith.index_cast %dim_2951 : index to i64 + %from_elements_2952 = tensor.from_elements %c1_i64, %8401, %c4096_i64, %c1_i64 : tensor<4xi64> + %8402 = stablehlo.dynamic_reshape %8400, %from_elements_2952 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %8403 = stablehlo.dynamic_broadcast_in_dim %8398, %from_elements_2950, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2953 = tensor.dim %8403, %c1 : tensor<1x?x4096xi64> + %8404 = arith.index_cast %dim_2953 : index to i64 + %from_elements_2954 = tensor.from_elements %c1_i64, %8404, %c4096_i64, %c1_i64 : tensor<4xi64> + %8405 = stablehlo.dynamic_reshape %8403, %from_elements_2954 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %8406 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_2950, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2955 = tensor.dim %8406, %c1 : tensor<1x?x4096xi64> + %8407 = arith.index_cast %dim_2955 : index to i64 + %from_elements_2956 = tensor.from_elements %c1_i64, %8407, %c4096_i64, %c1_i64 : tensor<4xi64> + %8408 = stablehlo.dynamic_reshape %8406, %from_elements_2956 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %8409 = stablehlo.concatenate %8402, %8405, %8408, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %8410 = "stablehlo.gather"(%8143, %8409) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %8411 = shape.shape_of %8410 : tensor<1x?x4096xf32> -> tensor<3xindex> + %8412 = shape.num_elements %8411 : tensor<3xindex> -> index + %8413 = stablehlo.compute_reshape_shape %8412, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %8414 = stablehlo.dynamic_reshape %8410, %8413 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %8415 = stablehlo.dot %8414, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %8416 = stablehlo.logistic %8415 : tensor + %8417 = shape.shape_of %8416 : tensor -> 
tensor<2xindex> + %8418 = shape.shape_of %8415 : tensor -> tensor<2xindex> + %8419 = shape.cstr_broadcastable %8417, %8418 : tensor<2xindex>, tensor<2xindex> + %8420 = shape.assuming %8419 -> (tensor) { + %19688 = shape.broadcast %8417, %8418 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %8416, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %8415, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %8421 = shape.shape_of %8420 : tensor -> tensor<2xindex> + %8422 = shape.cstr_broadcastable %8421, %8418 : tensor<2xindex>, tensor<2xindex> + %8423 = shape.assuming %8422 -> (tensor) { + %19688 = shape.broadcast %8421, %8418 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %8420, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %8415, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %8424 = stablehlo.dot %8423, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_2957 = tensor.dim %8396, %c0 : tensor + %8425 = arith.index_cast %dim_2957 : index to i64 + %from_elements_2958 = tensor.from_elements %8425, %c1_i64 : tensor<2xi64> + %8426 = stablehlo.dynamic_reshape %8396, %from_elements_2958 : (tensor, tensor<2xi64>) -> tensor + %dim_2959 = tensor.dim %8393, %c0 : tensor + %8427 = arith.index_cast %dim_2959 : index to i64 + %from_elements_2960 = tensor.from_elements %8427, %c1_i64 : tensor<2xi64> + %8428 = stablehlo.dynamic_reshape %8393, %from_elements_2960 : (tensor, tensor<2xi64>) -> tensor + %8429 = stablehlo.concatenate %8426, %8428, dim = 1 : (tensor, tensor) -> tensor + %8430 = "stablehlo.gather"(%8172, %8429) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %8431 = shape.shape_of %8424 : tensor -> tensor<2xindex> + %8432 = shape.shape_of %8430 : tensor -> tensor<2xindex> + %8433 = shape.cstr_broadcastable %8431, %8432 : tensor<2xindex>, tensor<2xindex> + %8434 = shape.assuming %8433 -> (tensor) { + %19688 = shape.broadcast %8431, %8432 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %8424, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %8430, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %8435 = shape.shape_of %8434 : tensor -> tensor<2xindex> + %8436 = stablehlo.dynamic_broadcast_in_dim %8434, %8435, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %8437 = stablehlo.dynamic_broadcast_in_dim %213, %8435, dims = [] : (tensor, tensor<2xindex>) -> tensor + %8438 = stablehlo.multiply %8436, %8437 : tensor + %dim_2961 = tensor.dim %8398, %c0 : tensor + %8439 = arith.index_cast %dim_2961 : index to i64 + %dim_2962 = tensor.dim %8434, %c0 : tensor + %8440 = arith.index_cast %dim_2962 : index to i64 + %8441 = arith.maxsi %8439, %8440 : i64 + %8442 = arith.index_cast %8441 : i64 to index + %from_elements_2963 = tensor.from_elements %8442, %c4096 : tensor<2xindex> + %8443 = stablehlo.dynamic_broadcast_in_dim %8398, %from_elements_2963, dims = [0, 1] : (tensor, 
tensor<2xindex>) -> tensor + %dim_2964 = tensor.dim %8443, %c0 : tensor + %8444 = arith.index_cast %dim_2964 : index to i64 + %from_elements_2965 = tensor.from_elements %8444, %c4096_i64 : tensor<2xi64> + %8445 = stablehlo.real_dynamic_slice %8438, %c_22, %from_elements_2965, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_2966 = tensor.from_elements %8444, %c4096_i64, %c1_i64 : tensor<3xi64> + %8446 = stablehlo.dynamic_reshape %8443, %from_elements_2966 : (tensor, tensor<3xi64>) -> tensor + %8447 = stablehlo.dynamic_iota %from_elements_2966, dim = 1 : (tensor<3xi64>) -> tensor + %8448 = stablehlo.concatenate %8446, %8447, dim = 2 : (tensor, tensor) -> tensor + %8449 = "stablehlo.scatter"(%8386, %8448, %8445) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %8450 = stablehlo.slice %8132 [5:6, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %8451 = stablehlo.reshape %8450 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %8452 = stablehlo.custom_call @byteir.non_zero(%8451) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_2967 = tensor.dim %8452, %c0 : tensor + %8453 = arith.index_cast %dim_2967 : index to i64 + %from_elements_2968 = tensor.from_elements %8453, %c1_i64 : tensor<2xi64> + %8454 = stablehlo.real_dynamic_slice %8452, %c_22, %from_elements_2968, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2969 = tensor.dim %8454, %c0 : tensor + %8455 = arith.index_cast %dim_2969 : index to i64 + %from_elements_2970 = tensor.from_elements %8455 : tensor<1xi64> + %8456 = stablehlo.dynamic_reshape %8454, %from_elements_2970 : (tensor, tensor<1xi64>) -> tensor + %from_elements_2971 = tensor.from_elements %8453, %c2_i64 : tensor<2xi64> + %8457 = stablehlo.real_dynamic_slice %8452, %c_24, %from_elements_2971, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2972 = tensor.dim %8457, %c0 : tensor + %8458 = arith.index_cast %dim_2972 : index to i64 + %from_elements_2973 = tensor.from_elements %8458 : tensor<1xi64> + %8459 = stablehlo.dynamic_reshape %8457, %from_elements_2973 : (tensor, tensor<1xi64>) -> tensor + %dim_2974 = tensor.dim %8459, %c0 : tensor + %8460 = arith.index_cast %dim_2974 : index to i64 + %from_elements_2975 = tensor.from_elements %8460, %c1_i64 : tensor<2xi64> + %8461 = stablehlo.dynamic_reshape %8459, %from_elements_2975 : (tensor, tensor<2xi64>) -> tensor + %dim_2976 = tensor.dim %8461, %c0 : tensor + %8462 = arith.index_cast %dim_2976 : index to i64 + %from_elements_2977 = tensor.from_elements %c1_i64, %8462, %c4096_i64 : tensor<3xi64> + %8463 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_2977, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2978 = tensor.dim %8463, %c1 : tensor<1x?x4096xi64> + %8464 = arith.index_cast %dim_2978 : index to i64 + %from_elements_2979 = tensor.from_elements %c1_i64, %8464, %c4096_i64, %c1_i64 : tensor<4xi64> + %8465 = stablehlo.dynamic_reshape %8463, %from_elements_2979 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %8466 = stablehlo.dynamic_broadcast_in_dim %8461, %from_elements_2977, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2980 = tensor.dim %8466, %c1 : tensor<1x?x4096xi64> + %8467 = 
arith.index_cast %dim_2980 : index to i64 + %from_elements_2981 = tensor.from_elements %c1_i64, %8467, %c4096_i64, %c1_i64 : tensor<4xi64> + %8468 = stablehlo.dynamic_reshape %8466, %from_elements_2981 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %8469 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_2977, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_2982 = tensor.dim %8469, %c1 : tensor<1x?x4096xi64> + %8470 = arith.index_cast %dim_2982 : index to i64 + %from_elements_2983 = tensor.from_elements %c1_i64, %8470, %c4096_i64, %c1_i64 : tensor<4xi64> + %8471 = stablehlo.dynamic_reshape %8469, %from_elements_2983 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %8472 = stablehlo.concatenate %8465, %8468, %8471, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %8473 = "stablehlo.gather"(%8143, %8472) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %8474 = shape.shape_of %8473 : tensor<1x?x4096xf32> -> tensor<3xindex> + %8475 = shape.num_elements %8474 : tensor<3xindex> -> index + %8476 = stablehlo.compute_reshape_shape %8475, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %8477 = stablehlo.dynamic_reshape %8473, %8476 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %8478 = stablehlo.dot %8477, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %8479 = stablehlo.logistic %8478 : tensor + %8480 = shape.shape_of %8479 : tensor -> tensor<2xindex> + %8481 = shape.shape_of %8478 : tensor -> tensor<2xindex> + %8482 = shape.cstr_broadcastable %8480, %8481 : tensor<2xindex>, tensor<2xindex> + %8483 = shape.assuming %8482 -> (tensor) { + %19688 = shape.broadcast %8480, %8481 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %8479, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %8478, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %8484 = shape.shape_of %8483 : tensor -> tensor<2xindex> + %8485 = shape.cstr_broadcastable %8484, %8481 : tensor<2xindex>, tensor<2xindex> + %8486 = shape.assuming %8485 -> (tensor) { + %19688 = shape.broadcast %8484, %8481 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %8483, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %8478, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %8487 = stablehlo.dot %8486, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_2984 = tensor.dim %8459, %c0 : tensor + %8488 = arith.index_cast %dim_2984 : index to i64 + %from_elements_2985 = tensor.from_elements %8488, %c1_i64 : tensor<2xi64> + %8489 = stablehlo.dynamic_reshape %8459, %from_elements_2985 : (tensor, tensor<2xi64>) -> tensor + %dim_2986 = tensor.dim %8456, %c0 : tensor + %8490 = arith.index_cast %dim_2986 : index to i64 + %from_elements_2987 = tensor.from_elements %8490, %c1_i64 : tensor<2xi64> + %8491 = stablehlo.dynamic_reshape %8456, %from_elements_2987 : (tensor, tensor<2xi64>) -> tensor + %8492 = stablehlo.concatenate %8489, %8491, dim = 1 : (tensor, tensor) -> tensor + 
%8493 = "stablehlo.gather"(%8172, %8492) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %8494 = shape.shape_of %8487 : tensor -> tensor<2xindex> + %8495 = shape.shape_of %8493 : tensor -> tensor<2xindex> + %8496 = shape.cstr_broadcastable %8494, %8495 : tensor<2xindex>, tensor<2xindex> + %8497 = shape.assuming %8496 -> (tensor) { + %19688 = shape.broadcast %8494, %8495 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %8487, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %8493, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %8498 = shape.shape_of %8497 : tensor -> tensor<2xindex> + %8499 = stablehlo.dynamic_broadcast_in_dim %8497, %8498, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %8500 = stablehlo.dynamic_broadcast_in_dim %213, %8498, dims = [] : (tensor, tensor<2xindex>) -> tensor + %8501 = stablehlo.multiply %8499, %8500 : tensor + %dim_2988 = tensor.dim %8461, %c0 : tensor + %8502 = arith.index_cast %dim_2988 : index to i64 + %dim_2989 = tensor.dim %8497, %c0 : tensor + %8503 = arith.index_cast %dim_2989 : index to i64 + %8504 = arith.maxsi %8502, %8503 : i64 + %8505 = arith.index_cast %8504 : i64 to index + %from_elements_2990 = tensor.from_elements %8505, %c4096 : tensor<2xindex> + %8506 = stablehlo.dynamic_broadcast_in_dim %8461, %from_elements_2990, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_2991 = tensor.dim %8506, %c0 : tensor + %8507 = arith.index_cast %dim_2991 : index to i64 + %from_elements_2992 = tensor.from_elements %8507, %c4096_i64 : tensor<2xi64> + %8508 = stablehlo.real_dynamic_slice %8501, %c_22, %from_elements_2992, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_2993 = tensor.from_elements %8507, %c4096_i64, %c1_i64 : tensor<3xi64> + %8509 = stablehlo.dynamic_reshape %8506, %from_elements_2993 : (tensor, tensor<3xi64>) -> tensor + %8510 = stablehlo.dynamic_iota %from_elements_2993, dim = 1 : (tensor<3xi64>) -> tensor + %8511 = stablehlo.concatenate %8509, %8510, dim = 2 : (tensor, tensor) -> tensor + %8512 = "stablehlo.scatter"(%8449, %8511, %8508) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %8513 = stablehlo.slice %8132 [6:7, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %8514 = stablehlo.reshape %8513 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %8515 = stablehlo.custom_call @byteir.non_zero(%8514) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_2994 = tensor.dim %8515, %c0 : tensor + %8516 = arith.index_cast %dim_2994 : index to i64 + %from_elements_2995 = tensor.from_elements %8516, %c1_i64 : tensor<2xi64> + %8517 = stablehlo.real_dynamic_slice %8515, %c_22, %from_elements_2995, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2996 = tensor.dim %8517, %c0 : tensor + %8518 = arith.index_cast %dim_2996 : index to i64 + %from_elements_2997 = tensor.from_elements %8518 : tensor<1xi64> + %8519 = stablehlo.dynamic_reshape %8517, %from_elements_2997 : (tensor, tensor<1xi64>) -> tensor + 
%from_elements_2998 = tensor.from_elements %8516, %c2_i64 : tensor<2xi64> + %8520 = stablehlo.real_dynamic_slice %8515, %c_24, %from_elements_2998, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_2999 = tensor.dim %8520, %c0 : tensor + %8521 = arith.index_cast %dim_2999 : index to i64 + %from_elements_3000 = tensor.from_elements %8521 : tensor<1xi64> + %8522 = stablehlo.dynamic_reshape %8520, %from_elements_3000 : (tensor, tensor<1xi64>) -> tensor + %dim_3001 = tensor.dim %8522, %c0 : tensor + %8523 = arith.index_cast %dim_3001 : index to i64 + %from_elements_3002 = tensor.from_elements %8523, %c1_i64 : tensor<2xi64> + %8524 = stablehlo.dynamic_reshape %8522, %from_elements_3002 : (tensor, tensor<2xi64>) -> tensor + %dim_3003 = tensor.dim %8524, %c0 : tensor + %8525 = arith.index_cast %dim_3003 : index to i64 + %from_elements_3004 = tensor.from_elements %c1_i64, %8525, %c4096_i64 : tensor<3xi64> + %8526 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_3004, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3005 = tensor.dim %8526, %c1 : tensor<1x?x4096xi64> + %8527 = arith.index_cast %dim_3005 : index to i64 + %from_elements_3006 = tensor.from_elements %c1_i64, %8527, %c4096_i64, %c1_i64 : tensor<4xi64> + %8528 = stablehlo.dynamic_reshape %8526, %from_elements_3006 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %8529 = stablehlo.dynamic_broadcast_in_dim %8524, %from_elements_3004, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3007 = tensor.dim %8529, %c1 : tensor<1x?x4096xi64> + %8530 = arith.index_cast %dim_3007 : index to i64 + %from_elements_3008 = tensor.from_elements %c1_i64, %8530, %c4096_i64, %c1_i64 : tensor<4xi64> + %8531 = stablehlo.dynamic_reshape %8529, %from_elements_3008 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %8532 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_3004, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3009 = tensor.dim %8532, %c1 : tensor<1x?x4096xi64> + %8533 = arith.index_cast %dim_3009 : index to i64 + %from_elements_3010 = tensor.from_elements %c1_i64, %8533, %c4096_i64, %c1_i64 : tensor<4xi64> + %8534 = stablehlo.dynamic_reshape %8532, %from_elements_3010 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %8535 = stablehlo.concatenate %8528, %8531, %8534, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %8536 = "stablehlo.gather"(%8143, %8535) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %8537 = shape.shape_of %8536 : tensor<1x?x4096xf32> -> tensor<3xindex> + %8538 = shape.num_elements %8537 : tensor<3xindex> -> index + %8539 = stablehlo.compute_reshape_shape %8538, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %8540 = stablehlo.dynamic_reshape %8536, %8539 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %8541 = stablehlo.dot %8540, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %8542 = stablehlo.logistic %8541 : tensor + %8543 = shape.shape_of %8542 : tensor -> tensor<2xindex> + %8544 = shape.shape_of %8541 : tensor -> tensor<2xindex> + %8545 = shape.cstr_broadcastable %8543, %8544 : tensor<2xindex>, tensor<2xindex> + %8546 = shape.assuming %8545 -> (tensor) { + %19688 = shape.broadcast %8543, %8544 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 
= stablehlo.dynamic_broadcast_in_dim %8542, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %8541, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %8547 = shape.shape_of %8546 : tensor -> tensor<2xindex> + %8548 = shape.cstr_broadcastable %8547, %8544 : tensor<2xindex>, tensor<2xindex> + %8549 = shape.assuming %8548 -> (tensor) { + %19688 = shape.broadcast %8547, %8544 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %8546, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %8541, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %8550 = stablehlo.dot %8549, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_3011 = tensor.dim %8522, %c0 : tensor + %8551 = arith.index_cast %dim_3011 : index to i64 + %from_elements_3012 = tensor.from_elements %8551, %c1_i64 : tensor<2xi64> + %8552 = stablehlo.dynamic_reshape %8522, %from_elements_3012 : (tensor, tensor<2xi64>) -> tensor + %dim_3013 = tensor.dim %8519, %c0 : tensor + %8553 = arith.index_cast %dim_3013 : index to i64 + %from_elements_3014 = tensor.from_elements %8553, %c1_i64 : tensor<2xi64> + %8554 = stablehlo.dynamic_reshape %8519, %from_elements_3014 : (tensor, tensor<2xi64>) -> tensor + %8555 = stablehlo.concatenate %8552, %8554, dim = 1 : (tensor, tensor) -> tensor + %8556 = "stablehlo.gather"(%8172, %8555) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %8557 = shape.shape_of %8550 : tensor -> tensor<2xindex> + %8558 = shape.shape_of %8556 : tensor -> tensor<2xindex> + %8559 = shape.cstr_broadcastable %8557, %8558 : tensor<2xindex>, tensor<2xindex> + %8560 = shape.assuming %8559 -> (tensor) { + %19688 = shape.broadcast %8557, %8558 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %8550, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %8556, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %8561 = shape.shape_of %8560 : tensor -> tensor<2xindex> + %8562 = stablehlo.dynamic_broadcast_in_dim %8560, %8561, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %8563 = stablehlo.dynamic_broadcast_in_dim %213, %8561, dims = [] : (tensor, tensor<2xindex>) -> tensor + %8564 = stablehlo.multiply %8562, %8563 : tensor + %dim_3015 = tensor.dim %8524, %c0 : tensor + %8565 = arith.index_cast %dim_3015 : index to i64 + %dim_3016 = tensor.dim %8560, %c0 : tensor + %8566 = arith.index_cast %dim_3016 : index to i64 + %8567 = arith.maxsi %8565, %8566 : i64 + %8568 = arith.index_cast %8567 : i64 to index + %from_elements_3017 = tensor.from_elements %8568, %c4096 : tensor<2xindex> + %8569 = stablehlo.dynamic_broadcast_in_dim %8524, %from_elements_3017, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_3018 = tensor.dim %8569, %c0 : tensor + %8570 = arith.index_cast %dim_3018 : index to i64 + %from_elements_3019 = tensor.from_elements %8570, %c4096_i64 : tensor<2xi64> + %8571 = stablehlo.real_dynamic_slice %8564, %c_22, %from_elements_3019, %c_21 : (tensor, 
tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_3020 = tensor.from_elements %8570, %c4096_i64, %c1_i64 : tensor<3xi64> + %8572 = stablehlo.dynamic_reshape %8569, %from_elements_3020 : (tensor, tensor<3xi64>) -> tensor + %8573 = stablehlo.dynamic_iota %from_elements_3020, dim = 1 : (tensor<3xi64>) -> tensor + %8574 = stablehlo.concatenate %8572, %8573, dim = 2 : (tensor, tensor) -> tensor + %8575 = "stablehlo.scatter"(%8512, %8574, %8571) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %8576 = stablehlo.slice %8132 [7:8, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %8577 = stablehlo.reshape %8576 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %8578 = stablehlo.custom_call @byteir.non_zero(%8577) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_3021 = tensor.dim %8578, %c0 : tensor + %8579 = arith.index_cast %dim_3021 : index to i64 + %from_elements_3022 = tensor.from_elements %8579, %c1_i64 : tensor<2xi64> + %8580 = stablehlo.real_dynamic_slice %8578, %c_22, %from_elements_3022, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3023 = tensor.dim %8580, %c0 : tensor + %8581 = arith.index_cast %dim_3023 : index to i64 + %from_elements_3024 = tensor.from_elements %8581 : tensor<1xi64> + %8582 = stablehlo.dynamic_reshape %8580, %from_elements_3024 : (tensor, tensor<1xi64>) -> tensor + %from_elements_3025 = tensor.from_elements %8579, %c2_i64 : tensor<2xi64> + %8583 = stablehlo.real_dynamic_slice %8578, %c_24, %from_elements_3025, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3026 = tensor.dim %8583, %c0 : tensor + %8584 = arith.index_cast %dim_3026 : index to i64 + %from_elements_3027 = tensor.from_elements %8584 : tensor<1xi64> + %8585 = stablehlo.dynamic_reshape %8583, %from_elements_3027 : (tensor, tensor<1xi64>) -> tensor + %dim_3028 = tensor.dim %8585, %c0 : tensor + %8586 = arith.index_cast %dim_3028 : index to i64 + %from_elements_3029 = tensor.from_elements %8586, %c1_i64 : tensor<2xi64> + %8587 = stablehlo.dynamic_reshape %8585, %from_elements_3029 : (tensor, tensor<2xi64>) -> tensor + %dim_3030 = tensor.dim %8587, %c0 : tensor + %8588 = arith.index_cast %dim_3030 : index to i64 + %from_elements_3031 = tensor.from_elements %c1_i64, %8588, %c4096_i64 : tensor<3xi64> + %8589 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_3031, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3032 = tensor.dim %8589, %c1 : tensor<1x?x4096xi64> + %8590 = arith.index_cast %dim_3032 : index to i64 + %from_elements_3033 = tensor.from_elements %c1_i64, %8590, %c4096_i64, %c1_i64 : tensor<4xi64> + %8591 = stablehlo.dynamic_reshape %8589, %from_elements_3033 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %8592 = stablehlo.dynamic_broadcast_in_dim %8587, %from_elements_3031, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3034 = tensor.dim %8592, %c1 : tensor<1x?x4096xi64> + %8593 = arith.index_cast %dim_3034 : index to i64 + %from_elements_3035 = tensor.from_elements %c1_i64, %8593, %c4096_i64, %c1_i64 : tensor<4xi64> + %8594 = stablehlo.dynamic_reshape %8592, %from_elements_3035 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %8595 = 
stablehlo.dynamic_broadcast_in_dim %160, %from_elements_3031, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3036 = tensor.dim %8595, %c1 : tensor<1x?x4096xi64> + %8596 = arith.index_cast %dim_3036 : index to i64 + %from_elements_3037 = tensor.from_elements %c1_i64, %8596, %c4096_i64, %c1_i64 : tensor<4xi64> + %8597 = stablehlo.dynamic_reshape %8595, %from_elements_3037 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %8598 = stablehlo.concatenate %8591, %8594, %8597, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %8599 = "stablehlo.gather"(%8143, %8598) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %8600 = shape.shape_of %8599 : tensor<1x?x4096xf32> -> tensor<3xindex> + %8601 = shape.num_elements %8600 : tensor<3xindex> -> index + %8602 = stablehlo.compute_reshape_shape %8601, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %8603 = stablehlo.dynamic_reshape %8599, %8602 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %8604 = stablehlo.dot %8603, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %8605 = stablehlo.logistic %8604 : tensor + %8606 = shape.shape_of %8605 : tensor -> tensor<2xindex> + %8607 = shape.shape_of %8604 : tensor -> tensor<2xindex> + %8608 = shape.cstr_broadcastable %8606, %8607 : tensor<2xindex>, tensor<2xindex> + %8609 = shape.assuming %8608 -> (tensor) { + %19688 = shape.broadcast %8606, %8607 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %8605, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %8604, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %8610 = shape.shape_of %8609 : tensor -> tensor<2xindex> + %8611 = shape.cstr_broadcastable %8610, %8607 : tensor<2xindex>, tensor<2xindex> + %8612 = shape.assuming %8611 -> (tensor) { + %19688 = shape.broadcast %8610, %8607 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %8609, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %8604, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %8613 = stablehlo.dot %8612, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_3038 = tensor.dim %8585, %c0 : tensor + %8614 = arith.index_cast %dim_3038 : index to i64 + %from_elements_3039 = tensor.from_elements %8614, %c1_i64 : tensor<2xi64> + %8615 = stablehlo.dynamic_reshape %8585, %from_elements_3039 : (tensor, tensor<2xi64>) -> tensor + %dim_3040 = tensor.dim %8582, %c0 : tensor + %8616 = arith.index_cast %dim_3040 : index to i64 + %from_elements_3041 = tensor.from_elements %8616, %c1_i64 : tensor<2xi64> + %8617 = stablehlo.dynamic_reshape %8582, %from_elements_3041 : (tensor, tensor<2xi64>) -> tensor + %8618 = stablehlo.concatenate %8615, %8617, dim = 1 : (tensor, tensor) -> tensor + %8619 = "stablehlo.gather"(%8172, %8618) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %8620 = shape.shape_of %8613 : tensor -> tensor<2xindex> + %8621 = shape.shape_of %8619 : tensor -> 
tensor<2xindex> + %8622 = shape.cstr_broadcastable %8620, %8621 : tensor<2xindex>, tensor<2xindex> + %8623 = shape.assuming %8622 -> (tensor) { + %19688 = shape.broadcast %8620, %8621 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %8613, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %8619, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %8624 = shape.shape_of %8623 : tensor -> tensor<2xindex> + %8625 = stablehlo.dynamic_broadcast_in_dim %8623, %8624, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %8626 = stablehlo.dynamic_broadcast_in_dim %213, %8624, dims = [] : (tensor, tensor<2xindex>) -> tensor + %8627 = stablehlo.multiply %8625, %8626 : tensor + %dim_3042 = tensor.dim %8587, %c0 : tensor + %8628 = arith.index_cast %dim_3042 : index to i64 + %dim_3043 = tensor.dim %8623, %c0 : tensor + %8629 = arith.index_cast %dim_3043 : index to i64 + %8630 = arith.maxsi %8628, %8629 : i64 + %8631 = arith.index_cast %8630 : i64 to index + %from_elements_3044 = tensor.from_elements %8631, %c4096 : tensor<2xindex> + %8632 = stablehlo.dynamic_broadcast_in_dim %8587, %from_elements_3044, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_3045 = tensor.dim %8632, %c0 : tensor + %8633 = arith.index_cast %dim_3045 : index to i64 + %from_elements_3046 = tensor.from_elements %8633, %c4096_i64 : tensor<2xi64> + %8634 = stablehlo.real_dynamic_slice %8627, %c_22, %from_elements_3046, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_3047 = tensor.from_elements %8633, %c4096_i64, %c1_i64 : tensor<3xi64> + %8635 = stablehlo.dynamic_reshape %8632, %from_elements_3047 : (tensor, tensor<3xi64>) -> tensor + %8636 = stablehlo.dynamic_iota %from_elements_3047, dim = 1 : (tensor<3xi64>) -> tensor + %8637 = stablehlo.concatenate %8635, %8636, dim = 2 : (tensor, tensor) -> tensor + %8638 = "stablehlo.scatter"(%8575, %8637, %8634) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %8639 = stablehlo.reshape %8638 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %8640 = stablehlo.add %8105, %8639 : tensor<3x1x4096xf32> + %8641 = stablehlo.broadcast_in_dim %8640, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %8642 = stablehlo.power %8641, %15 : tensor<3x1x4096xf32> + %8643 = stablehlo.reduce(%8642 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %8644 = stablehlo.reshape %8643 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %8645 = stablehlo.broadcast_in_dim %8644, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %8646 = stablehlo.divide %8645, %21 : tensor<3x1x1xf32> + %8647 = stablehlo.broadcast_in_dim %8646, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %8648 = stablehlo.add %8647, %25 : tensor<3x1x1xf32> + %8649 = stablehlo.rsqrt %8648 : tensor<3x1x1xf32> + %8650 = stablehlo.broadcast_in_dim %8649, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %8651 = stablehlo.multiply %8641, %8650 : tensor<3x1x4096xf32> + %8652 = stablehlo.broadcast_in_dim %8651, dims = [0, 1, 2] : 
(tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %8653 = stablehlo.multiply %8652, %31 : tensor<3x1x4096xf32> + %8654 = stablehlo.reshape %8653 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %8655 = stablehlo.dot %8654, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %8656 = stablehlo.reshape %8655 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %8657 = stablehlo.dot %8654, %37 : (tensor<3x4096xf32>, tensor<4096x1024xf32>) -> tensor<3x1024xf32> + %8658 = stablehlo.reshape %8657 : (tensor<3x1024xf32>) -> tensor<3x1x1024xf32> + %8659 = stablehlo.reshape %8656 : (tensor<3x1x4096xf32>) -> tensor<3x1x32x128xf32> + %8660 = stablehlo.transpose %8659, dims = [0, 2, 1, 3] : (tensor<3x1x32x128xf32>) -> tensor<3x32x1x128xf32> + %8661 = stablehlo.reshape %8658 : (tensor<3x1x1024xf32>) -> tensor<3x1x8x128xf32> + %8662 = stablehlo.transpose %8661, dims = [0, 2, 1, 3] : (tensor<3x1x8x128xf32>) -> tensor<3x8x1x128xf32> + %8663 = stablehlo.slice %arg28 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %8664 = stablehlo.slice %arg29 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %8665 = "stablehlo.gather"(%8663, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %8666 = stablehlo.reshape %8665 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %8667 = "stablehlo.gather"(%8664, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %8668 = stablehlo.reshape %8667 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %8669 = stablehlo.broadcast_in_dim %8660, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %8670 = stablehlo.broadcast_in_dim %8666, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %8671 = stablehlo.multiply %8669, %8670 : tensor<3x32x1x128xf32> + %8672 = stablehlo.slice %8660 [0:3, 0:32, 0:1, 0:64] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %8673 = stablehlo.slice %8660 [0:3, 0:32, 0:1, 64:128] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %8674 = stablehlo.negate %8673 : tensor<3x32x1x64xf32> + %8675 = stablehlo.concatenate %8674, %8672, dim = 3 : (tensor<3x32x1x64xf32>, tensor<3x32x1x64xf32>) -> tensor<3x32x1x128xf32> + %8676 = stablehlo.broadcast_in_dim %8675, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %8677 = stablehlo.broadcast_in_dim %8668, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %8678 = stablehlo.multiply %8676, %8677 : tensor<3x32x1x128xf32> + %8679 = stablehlo.add %8671, %8678 : tensor<3x32x1x128xf32> + %8680 = stablehlo.broadcast_in_dim %8662, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %8681 = stablehlo.broadcast_in_dim %8666, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %8682 = stablehlo.multiply %8680, %8681 : tensor<3x8x1x128xf32> + %8683 = stablehlo.slice %8662 [0:3, 0:8, 0:1, 0:64] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %8684 = stablehlo.slice %8662 [0:3, 0:8, 0:1, 64:128] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %8685 = stablehlo.negate %8684 : tensor<3x8x1x64xf32> + %8686 = stablehlo.concatenate %8685, %8683, dim = 3 : (tensor<3x8x1x64xf32>, tensor<3x8x1x64xf32>) -> tensor<3x8x1x128xf32> + %8687 = stablehlo.broadcast_in_dim %8686, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> 
tensor<3x8x1x128xf32> + %8688 = stablehlo.broadcast_in_dim %8668, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %8689 = stablehlo.multiply %8687, %8688 : tensor<3x8x1x128xf32> + %8690 = stablehlo.add %8682, %8689 : tensor<3x8x1x128xf32> + %8691 = stablehlo.concatenate %arg93, %8690, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %8692 = stablehlo.concatenate %arg94, %8662, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %8693 = stablehlo.reshape %8691 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %8694 = stablehlo.broadcast_in_dim %8693, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %8695 = stablehlo.reshape %8694 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %8696 = stablehlo.reshape %8692 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %8697 = stablehlo.broadcast_in_dim %8696, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %8698 = stablehlo.reshape %8697 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %8699 = stablehlo.transpose %8695, dims = [0, 1, 3, 2] : (tensor<3x32x8x128xf32>) -> tensor<3x32x128x8xf32> + %8700 = stablehlo.reshape %8679 : (tensor<3x32x1x128xf32>) -> tensor<96x1x128xf32> + %8701 = stablehlo.reshape %8699 : (tensor<3x32x128x8xf32>) -> tensor<96x128x8xf32> + %8702 = stablehlo.broadcast_in_dim %8701, dims = [0, 1, 2] : (tensor<96x128x8xf32>) -> tensor<96x128x8xf32> + %8703 = stablehlo.dot_general %8700, %8702, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x128xf32>, tensor<96x128x8xf32>) -> tensor<96x1x8xf32> + %8704 = stablehlo.reshape %8703 : (tensor<96x1x8xf32>) -> tensor<3x32x1x8xf32> + %8705 = stablehlo.broadcast_in_dim %8704, dims = [0, 1, 2, 3] : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %8706 = stablehlo.divide %8705, %89 : tensor<3x32x1x8xf32> + %8707 = stablehlo.custom_call @byteir.softmax(%8706) {byteir_attrs = {axis = 3 : i64}} : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %8708 = stablehlo.reshape %8707 : (tensor<3x32x1x8xf32>) -> tensor<96x1x8xf32> + %8709 = stablehlo.reshape %8698 : (tensor<3x32x8x128xf32>) -> tensor<96x8x128xf32> + %8710 = stablehlo.broadcast_in_dim %8709, dims = [0, 1, 2] : (tensor<96x8x128xf32>) -> tensor<96x8x128xf32> + %8711 = stablehlo.dot_general %8708, %8710, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x8xf32>, tensor<96x8x128xf32>) -> tensor<96x1x128xf32> + %8712 = stablehlo.reshape %8711 : (tensor<96x1x128xf32>) -> tensor<3x32x1x128xf32> + %8713 = stablehlo.transpose %8712, dims = [0, 2, 1, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x1x32x128xf32> + %8714 = stablehlo.reshape %8713 : (tensor<3x1x32x128xf32>) -> tensor<3x1x4096xf32> + %8715 = stablehlo.reshape %8714 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %8716 = stablehlo.dot %8715, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %8717 = stablehlo.reshape %8716 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %8718 = stablehlo.add %8640, %8717 : tensor<3x1x4096xf32> + %8719 = stablehlo.broadcast_in_dim %8718, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %8720 = stablehlo.power %8719, %15 : tensor<3x1x4096xf32> + %8721 = stablehlo.reduce(%8720 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %8722 = stablehlo.reshape %8721 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %8723 = stablehlo.broadcast_in_dim 
%8722, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %8724 = stablehlo.divide %8723, %21 : tensor<3x1x1xf32> + %8725 = stablehlo.broadcast_in_dim %8724, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %8726 = stablehlo.add %8725, %25 : tensor<3x1x1xf32> + %8727 = stablehlo.rsqrt %8726 : tensor<3x1x1xf32> + %8728 = stablehlo.broadcast_in_dim %8727, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %8729 = stablehlo.multiply %8719, %8728 : tensor<3x1x4096xf32> + %8730 = stablehlo.broadcast_in_dim %8729, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %8731 = stablehlo.multiply %8730, %31 : tensor<3x1x4096xf32> + %8732 = stablehlo.reshape %8731 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %8733 = stablehlo.dot %8732, %117 : (tensor<3x4096xf32>, tensor<4096x8xf32>) -> tensor<3x8xf32> + %8734 = stablehlo.custom_call @byteir.softmax(%8733) {byteir_attrs = {axis = 1 : i64}} : (tensor<3x8xf32>) -> tensor<3x8xf32> + %8735:2 = stablehlo.custom_call @byteir.top_k(%8734) {byteir_attrs = {axis = [1], k = 2 : i64, sorted = true}} : (tensor<3x8xf32>) -> (tensor<3x2xf32>, tensor<3x2xi64>) + %8736 = stablehlo.reduce(%8735#0 init: %cst_20) applies stablehlo.add across dimensions = [1] : (tensor<3x2xf32>, tensor) -> tensor<3xf32> + %8737 = stablehlo.reshape %8736 : (tensor<3xf32>) -> tensor<3x1xf32> + %8738 = stablehlo.broadcast_in_dim %8735#0, dims = [0, 1] : (tensor<3x2xf32>) -> tensor<3x2xf32> + %8739 = stablehlo.broadcast_in_dim %8737, dims = [0, 1] : (tensor<3x1xf32>) -> tensor<3x2xf32> + %8740 = stablehlo.divide %8738, %8739 : tensor<3x2xf32> + %8741 = stablehlo.reshape %8735#1 : (tensor<3x2xi64>) -> tensor<3x2x1xi64> + %8742 = stablehlo.broadcast_in_dim %8741, dims = [0, 1, 2] : (tensor<3x2x1xi64>) -> tensor<3x2x8xi64> + %8743 = stablehlo.compare EQ, %8742, %137, SIGNED : (tensor<3x2x8xi64>, tensor<3x2x8xi64>) -> tensor<3x2x8xi1> + %8744 = stablehlo.convert %8743 : (tensor<3x2x8xi1>) -> tensor<3x2x8xi64> + %8745 = stablehlo.transpose %8744, dims = [2, 1, 0] : (tensor<3x2x8xi64>) -> tensor<8x2x3xi64> + %8746 = stablehlo.slice %8745 [0:1, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %8747 = stablehlo.reshape %8746 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %8748 = stablehlo.custom_call @byteir.non_zero(%8747) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_3048 = tensor.dim %8748, %c0 : tensor + %8749 = arith.index_cast %dim_3048 : index to i64 + %from_elements_3049 = tensor.from_elements %8749, %c1_i64 : tensor<2xi64> + %8750 = stablehlo.real_dynamic_slice %8748, %c_22, %from_elements_3049, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3050 = tensor.dim %8750, %c0 : tensor + %8751 = arith.index_cast %dim_3050 : index to i64 + %from_elements_3051 = tensor.from_elements %8751 : tensor<1xi64> + %8752 = stablehlo.dynamic_reshape %8750, %from_elements_3051 : (tensor, tensor<1xi64>) -> tensor + %from_elements_3052 = tensor.from_elements %8749, %c2_i64 : tensor<2xi64> + %8753 = stablehlo.real_dynamic_slice %8748, %c_24, %from_elements_3052, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3053 = tensor.dim %8753, %c0 : tensor + %8754 = arith.index_cast %dim_3053 : index to i64 + %from_elements_3054 = tensor.from_elements %8754 : tensor<1xi64> + %8755 = stablehlo.dynamic_reshape %8753, %from_elements_3054 : (tensor, tensor<1xi64>) -> tensor + %8756 = stablehlo.reshape %8732 : (tensor<3x4096xf32>) -> tensor<1x3x4096xf32> + %dim_3055 = tensor.dim %8755, %c0 : 
tensor + %8757 = arith.index_cast %dim_3055 : index to i64 + %from_elements_3056 = tensor.from_elements %8757, %c1_i64 : tensor<2xi64> + %8758 = stablehlo.dynamic_reshape %8755, %from_elements_3056 : (tensor, tensor<2xi64>) -> tensor + %dim_3057 = tensor.dim %8758, %c0 : tensor + %8759 = arith.index_cast %dim_3057 : index to i64 + %from_elements_3058 = tensor.from_elements %c1_i64, %8759, %c4096_i64 : tensor<3xi64> + %8760 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_3058, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3059 = tensor.dim %8760, %c1 : tensor<1x?x4096xi64> + %8761 = arith.index_cast %dim_3059 : index to i64 + %from_elements_3060 = tensor.from_elements %c1_i64, %8761, %c4096_i64, %c1_i64 : tensor<4xi64> + %8762 = stablehlo.dynamic_reshape %8760, %from_elements_3060 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %8763 = stablehlo.dynamic_broadcast_in_dim %8758, %from_elements_3058, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3061 = tensor.dim %8763, %c1 : tensor<1x?x4096xi64> + %8764 = arith.index_cast %dim_3061 : index to i64 + %from_elements_3062 = tensor.from_elements %c1_i64, %8764, %c4096_i64, %c1_i64 : tensor<4xi64> + %8765 = stablehlo.dynamic_reshape %8763, %from_elements_3062 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %8766 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_3058, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3063 = tensor.dim %8766, %c1 : tensor<1x?x4096xi64> + %8767 = arith.index_cast %dim_3063 : index to i64 + %from_elements_3064 = tensor.from_elements %c1_i64, %8767, %c4096_i64, %c1_i64 : tensor<4xi64> + %8768 = stablehlo.dynamic_reshape %8766, %from_elements_3064 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %8769 = stablehlo.concatenate %8762, %8765, %8768, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %8770 = "stablehlo.gather"(%8756, %8769) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %8771 = shape.shape_of %8770 : tensor<1x?x4096xf32> -> tensor<3xindex> + %8772 = shape.num_elements %8771 : tensor<3xindex> -> index + %8773 = stablehlo.compute_reshape_shape %8772, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %8774 = stablehlo.dynamic_reshape %8770, %8773 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %8775 = stablehlo.dot %8774, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %8776 = stablehlo.logistic %8775 : tensor + %8777 = shape.shape_of %8776 : tensor -> tensor<2xindex> + %8778 = shape.shape_of %8775 : tensor -> tensor<2xindex> + %8779 = shape.cstr_broadcastable %8777, %8778 : tensor<2xindex>, tensor<2xindex> + %8780 = shape.assuming %8779 -> (tensor) { + %19688 = shape.broadcast %8777, %8778 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %8776, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %8775, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %8781 = shape.shape_of %8780 : tensor -> tensor<2xindex> + %8782 = shape.cstr_broadcastable %8781, %8778 : tensor<2xindex>, tensor<2xindex> + %8783 = shape.assuming %8782 -> (tensor) { + %19688 = 
shape.broadcast %8781, %8778 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %8780, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %8775, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %8784 = stablehlo.dot %8783, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %8785 = stablehlo.reshape %8740 : (tensor<3x2xf32>) -> tensor<3x2x1xf32> + %dim_3065 = tensor.dim %8755, %c0 : tensor + %8786 = arith.index_cast %dim_3065 : index to i64 + %from_elements_3066 = tensor.from_elements %8786, %c1_i64 : tensor<2xi64> + %8787 = stablehlo.dynamic_reshape %8755, %from_elements_3066 : (tensor, tensor<2xi64>) -> tensor + %dim_3067 = tensor.dim %8752, %c0 : tensor + %8788 = arith.index_cast %dim_3067 : index to i64 + %from_elements_3068 = tensor.from_elements %8788, %c1_i64 : tensor<2xi64> + %8789 = stablehlo.dynamic_reshape %8752, %from_elements_3068 : (tensor, tensor<2xi64>) -> tensor + %8790 = stablehlo.concatenate %8787, %8789, dim = 1 : (tensor, tensor) -> tensor + %8791 = "stablehlo.gather"(%8785, %8790) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %8792 = shape.shape_of %8784 : tensor -> tensor<2xindex> + %8793 = shape.shape_of %8791 : tensor -> tensor<2xindex> + %8794 = shape.cstr_broadcastable %8792, %8793 : tensor<2xindex>, tensor<2xindex> + %8795 = shape.assuming %8794 -> (tensor) { + %19688 = shape.broadcast %8792, %8793 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %8784, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %8791, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %8796 = shape.shape_of %8795 : tensor -> tensor<2xindex> + %8797 = stablehlo.dynamic_broadcast_in_dim %8795, %8796, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %8798 = stablehlo.dynamic_broadcast_in_dim %213, %8796, dims = [] : (tensor, tensor<2xindex>) -> tensor + %8799 = stablehlo.multiply %8797, %8798 : tensor + %dim_3069 = tensor.dim %8758, %c0 : tensor + %8800 = arith.index_cast %dim_3069 : index to i64 + %dim_3070 = tensor.dim %8795, %c0 : tensor + %8801 = arith.index_cast %dim_3070 : index to i64 + %8802 = arith.maxsi %8800, %8801 : i64 + %8803 = arith.index_cast %8802 : i64 to index + %from_elements_3071 = tensor.from_elements %8803, %c4096 : tensor<2xindex> + %8804 = stablehlo.dynamic_broadcast_in_dim %8758, %from_elements_3071, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_3072 = tensor.dim %8804, %c0 : tensor + %8805 = arith.index_cast %dim_3072 : index to i64 + %from_elements_3073 = tensor.from_elements %8805, %c4096_i64 : tensor<2xi64> + %8806 = stablehlo.real_dynamic_slice %8799, %c_22, %from_elements_3073, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_3074 = tensor.from_elements %8805, %c4096_i64, %c1_i64 : tensor<3xi64> + %8807 = stablehlo.dynamic_reshape %8804, %from_elements_3074 : (tensor, tensor<3xi64>) -> tensor + %8808 = stablehlo.dynamic_iota %from_elements_3074, dim = 1 : (tensor<3xi64>) -> tensor + %8809 = stablehlo.concatenate %8807, %8808, dim = 2 : (tensor, tensor) -> tensor + %8810 = 
"stablehlo.scatter"(%cst_2, %8809, %8806) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %8811 = stablehlo.slice %8745 [1:2, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %8812 = stablehlo.reshape %8811 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %8813 = stablehlo.custom_call @byteir.non_zero(%8812) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_3075 = tensor.dim %8813, %c0 : tensor + %8814 = arith.index_cast %dim_3075 : index to i64 + %from_elements_3076 = tensor.from_elements %8814, %c1_i64 : tensor<2xi64> + %8815 = stablehlo.real_dynamic_slice %8813, %c_22, %from_elements_3076, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3077 = tensor.dim %8815, %c0 : tensor + %8816 = arith.index_cast %dim_3077 : index to i64 + %from_elements_3078 = tensor.from_elements %8816 : tensor<1xi64> + %8817 = stablehlo.dynamic_reshape %8815, %from_elements_3078 : (tensor, tensor<1xi64>) -> tensor + %from_elements_3079 = tensor.from_elements %8814, %c2_i64 : tensor<2xi64> + %8818 = stablehlo.real_dynamic_slice %8813, %c_24, %from_elements_3079, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3080 = tensor.dim %8818, %c0 : tensor + %8819 = arith.index_cast %dim_3080 : index to i64 + %from_elements_3081 = tensor.from_elements %8819 : tensor<1xi64> + %8820 = stablehlo.dynamic_reshape %8818, %from_elements_3081 : (tensor, tensor<1xi64>) -> tensor + %dim_3082 = tensor.dim %8820, %c0 : tensor + %8821 = arith.index_cast %dim_3082 : index to i64 + %from_elements_3083 = tensor.from_elements %8821, %c1_i64 : tensor<2xi64> + %8822 = stablehlo.dynamic_reshape %8820, %from_elements_3083 : (tensor, tensor<2xi64>) -> tensor + %dim_3084 = tensor.dim %8822, %c0 : tensor + %8823 = arith.index_cast %dim_3084 : index to i64 + %from_elements_3085 = tensor.from_elements %c1_i64, %8823, %c4096_i64 : tensor<3xi64> + %8824 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_3085, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3086 = tensor.dim %8824, %c1 : tensor<1x?x4096xi64> + %8825 = arith.index_cast %dim_3086 : index to i64 + %from_elements_3087 = tensor.from_elements %c1_i64, %8825, %c4096_i64, %c1_i64 : tensor<4xi64> + %8826 = stablehlo.dynamic_reshape %8824, %from_elements_3087 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %8827 = stablehlo.dynamic_broadcast_in_dim %8822, %from_elements_3085, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3088 = tensor.dim %8827, %c1 : tensor<1x?x4096xi64> + %8828 = arith.index_cast %dim_3088 : index to i64 + %from_elements_3089 = tensor.from_elements %c1_i64, %8828, %c4096_i64, %c1_i64 : tensor<4xi64> + %8829 = stablehlo.dynamic_reshape %8827, %from_elements_3089 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %8830 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_3085, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3090 = tensor.dim %8830, %c1 : tensor<1x?x4096xi64> + %8831 = arith.index_cast %dim_3090 : index to i64 + %from_elements_3091 = tensor.from_elements %c1_i64, %8831, %c4096_i64, %c1_i64 : tensor<4xi64> + %8832 = stablehlo.dynamic_reshape %8830, %from_elements_3091 : (tensor<1x?x4096xi64>, tensor<4xi64>) 
-> tensor<1x?x4096x1xi64> + %8833 = stablehlo.concatenate %8826, %8829, %8832, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %8834 = "stablehlo.gather"(%8756, %8833) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %8835 = shape.shape_of %8834 : tensor<1x?x4096xf32> -> tensor<3xindex> + %8836 = shape.num_elements %8835 : tensor<3xindex> -> index + %8837 = stablehlo.compute_reshape_shape %8836, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %8838 = stablehlo.dynamic_reshape %8834, %8837 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %8839 = stablehlo.dot %8838, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %8840 = stablehlo.logistic %8839 : tensor + %8841 = shape.shape_of %8840 : tensor -> tensor<2xindex> + %8842 = shape.shape_of %8839 : tensor -> tensor<2xindex> + %8843 = shape.cstr_broadcastable %8841, %8842 : tensor<2xindex>, tensor<2xindex> + %8844 = shape.assuming %8843 -> (tensor) { + %19688 = shape.broadcast %8841, %8842 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %8840, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %8839, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %8845 = shape.shape_of %8844 : tensor -> tensor<2xindex> + %8846 = shape.cstr_broadcastable %8845, %8842 : tensor<2xindex>, tensor<2xindex> + %8847 = shape.assuming %8846 -> (tensor) { + %19688 = shape.broadcast %8845, %8842 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %8844, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %8839, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %8848 = stablehlo.dot %8847, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_3092 = tensor.dim %8820, %c0 : tensor + %8849 = arith.index_cast %dim_3092 : index to i64 + %from_elements_3093 = tensor.from_elements %8849, %c1_i64 : tensor<2xi64> + %8850 = stablehlo.dynamic_reshape %8820, %from_elements_3093 : (tensor, tensor<2xi64>) -> tensor + %dim_3094 = tensor.dim %8817, %c0 : tensor + %8851 = arith.index_cast %dim_3094 : index to i64 + %from_elements_3095 = tensor.from_elements %8851, %c1_i64 : tensor<2xi64> + %8852 = stablehlo.dynamic_reshape %8817, %from_elements_3095 : (tensor, tensor<2xi64>) -> tensor + %8853 = stablehlo.concatenate %8850, %8852, dim = 1 : (tensor, tensor) -> tensor + %8854 = "stablehlo.gather"(%8785, %8853) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %8855 = shape.shape_of %8848 : tensor -> tensor<2xindex> + %8856 = shape.shape_of %8854 : tensor -> tensor<2xindex> + %8857 = shape.cstr_broadcastable %8855, %8856 : tensor<2xindex>, tensor<2xindex> + %8858 = shape.assuming %8857 -> (tensor) { + %19688 = shape.broadcast %8855, %8856 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %8848, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %8854, %19688, dims = [0, 1] : (tensor, 
tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %8859 = shape.shape_of %8858 : tensor -> tensor<2xindex> + %8860 = stablehlo.dynamic_broadcast_in_dim %8858, %8859, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %8861 = stablehlo.dynamic_broadcast_in_dim %213, %8859, dims = [] : (tensor, tensor<2xindex>) -> tensor + %8862 = stablehlo.multiply %8860, %8861 : tensor + %dim_3096 = tensor.dim %8822, %c0 : tensor + %8863 = arith.index_cast %dim_3096 : index to i64 + %dim_3097 = tensor.dim %8858, %c0 : tensor + %8864 = arith.index_cast %dim_3097 : index to i64 + %8865 = arith.maxsi %8863, %8864 : i64 + %8866 = arith.index_cast %8865 : i64 to index + %from_elements_3098 = tensor.from_elements %8866, %c4096 : tensor<2xindex> + %8867 = stablehlo.dynamic_broadcast_in_dim %8822, %from_elements_3098, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_3099 = tensor.dim %8867, %c0 : tensor + %8868 = arith.index_cast %dim_3099 : index to i64 + %from_elements_3100 = tensor.from_elements %8868, %c4096_i64 : tensor<2xi64> + %8869 = stablehlo.real_dynamic_slice %8862, %c_22, %from_elements_3100, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_3101 = tensor.from_elements %8868, %c4096_i64, %c1_i64 : tensor<3xi64> + %8870 = stablehlo.dynamic_reshape %8867, %from_elements_3101 : (tensor, tensor<3xi64>) -> tensor + %8871 = stablehlo.dynamic_iota %from_elements_3101, dim = 1 : (tensor<3xi64>) -> tensor + %8872 = stablehlo.concatenate %8870, %8871, dim = 2 : (tensor, tensor) -> tensor + %8873 = "stablehlo.scatter"(%8810, %8872, %8869) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %8874 = stablehlo.slice %8745 [2:3, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %8875 = stablehlo.reshape %8874 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %8876 = stablehlo.custom_call @byteir.non_zero(%8875) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_3102 = tensor.dim %8876, %c0 : tensor + %8877 = arith.index_cast %dim_3102 : index to i64 + %from_elements_3103 = tensor.from_elements %8877, %c1_i64 : tensor<2xi64> + %8878 = stablehlo.real_dynamic_slice %8876, %c_22, %from_elements_3103, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3104 = tensor.dim %8878, %c0 : tensor + %8879 = arith.index_cast %dim_3104 : index to i64 + %from_elements_3105 = tensor.from_elements %8879 : tensor<1xi64> + %8880 = stablehlo.dynamic_reshape %8878, %from_elements_3105 : (tensor, tensor<1xi64>) -> tensor + %from_elements_3106 = tensor.from_elements %8877, %c2_i64 : tensor<2xi64> + %8881 = stablehlo.real_dynamic_slice %8876, %c_24, %from_elements_3106, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3107 = tensor.dim %8881, %c0 : tensor + %8882 = arith.index_cast %dim_3107 : index to i64 + %from_elements_3108 = tensor.from_elements %8882 : tensor<1xi64> + %8883 = stablehlo.dynamic_reshape %8881, %from_elements_3108 : (tensor, tensor<1xi64>) -> tensor + %dim_3109 = tensor.dim %8883, %c0 : tensor + %8884 = arith.index_cast %dim_3109 : index to i64 + %from_elements_3110 = tensor.from_elements %8884, %c1_i64 : tensor<2xi64> + %8885 = stablehlo.dynamic_reshape %8883, %from_elements_3110 : (tensor, 
tensor<2xi64>) -> tensor + %dim_3111 = tensor.dim %8885, %c0 : tensor + %8886 = arith.index_cast %dim_3111 : index to i64 + %from_elements_3112 = tensor.from_elements %c1_i64, %8886, %c4096_i64 : tensor<3xi64> + %8887 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_3112, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3113 = tensor.dim %8887, %c1 : tensor<1x?x4096xi64> + %8888 = arith.index_cast %dim_3113 : index to i64 + %from_elements_3114 = tensor.from_elements %c1_i64, %8888, %c4096_i64, %c1_i64 : tensor<4xi64> + %8889 = stablehlo.dynamic_reshape %8887, %from_elements_3114 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %8890 = stablehlo.dynamic_broadcast_in_dim %8885, %from_elements_3112, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3115 = tensor.dim %8890, %c1 : tensor<1x?x4096xi64> + %8891 = arith.index_cast %dim_3115 : index to i64 + %from_elements_3116 = tensor.from_elements %c1_i64, %8891, %c4096_i64, %c1_i64 : tensor<4xi64> + %8892 = stablehlo.dynamic_reshape %8890, %from_elements_3116 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %8893 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_3112, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3117 = tensor.dim %8893, %c1 : tensor<1x?x4096xi64> + %8894 = arith.index_cast %dim_3117 : index to i64 + %from_elements_3118 = tensor.from_elements %c1_i64, %8894, %c4096_i64, %c1_i64 : tensor<4xi64> + %8895 = stablehlo.dynamic_reshape %8893, %from_elements_3118 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %8896 = stablehlo.concatenate %8889, %8892, %8895, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %8897 = "stablehlo.gather"(%8756, %8896) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %8898 = shape.shape_of %8897 : tensor<1x?x4096xf32> -> tensor<3xindex> + %8899 = shape.num_elements %8898 : tensor<3xindex> -> index + %8900 = stablehlo.compute_reshape_shape %8899, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %8901 = stablehlo.dynamic_reshape %8897, %8900 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %8902 = stablehlo.dot %8901, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %8903 = stablehlo.logistic %8902 : tensor + %8904 = shape.shape_of %8903 : tensor -> tensor<2xindex> + %8905 = shape.shape_of %8902 : tensor -> tensor<2xindex> + %8906 = shape.cstr_broadcastable %8904, %8905 : tensor<2xindex>, tensor<2xindex> + %8907 = shape.assuming %8906 -> (tensor) { + %19688 = shape.broadcast %8904, %8905 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %8903, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %8902, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %8908 = shape.shape_of %8907 : tensor -> tensor<2xindex> + %8909 = shape.cstr_broadcastable %8908, %8905 : tensor<2xindex>, tensor<2xindex> + %8910 = shape.assuming %8909 -> (tensor) { + %19688 = shape.broadcast %8908, %8905 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %8907, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = 
stablehlo.dynamic_broadcast_in_dim %8902, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %8911 = stablehlo.dot %8910, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_3119 = tensor.dim %8883, %c0 : tensor + %8912 = arith.index_cast %dim_3119 : index to i64 + %from_elements_3120 = tensor.from_elements %8912, %c1_i64 : tensor<2xi64> + %8913 = stablehlo.dynamic_reshape %8883, %from_elements_3120 : (tensor, tensor<2xi64>) -> tensor + %dim_3121 = tensor.dim %8880, %c0 : tensor + %8914 = arith.index_cast %dim_3121 : index to i64 + %from_elements_3122 = tensor.from_elements %8914, %c1_i64 : tensor<2xi64> + %8915 = stablehlo.dynamic_reshape %8880, %from_elements_3122 : (tensor, tensor<2xi64>) -> tensor + %8916 = stablehlo.concatenate %8913, %8915, dim = 1 : (tensor, tensor) -> tensor + %8917 = "stablehlo.gather"(%8785, %8916) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %8918 = shape.shape_of %8911 : tensor -> tensor<2xindex> + %8919 = shape.shape_of %8917 : tensor -> tensor<2xindex> + %8920 = shape.cstr_broadcastable %8918, %8919 : tensor<2xindex>, tensor<2xindex> + %8921 = shape.assuming %8920 -> (tensor) { + %19688 = shape.broadcast %8918, %8919 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %8911, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %8917, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %8922 = shape.shape_of %8921 : tensor -> tensor<2xindex> + %8923 = stablehlo.dynamic_broadcast_in_dim %8921, %8922, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %8924 = stablehlo.dynamic_broadcast_in_dim %213, %8922, dims = [] : (tensor, tensor<2xindex>) -> tensor + %8925 = stablehlo.multiply %8923, %8924 : tensor + %dim_3123 = tensor.dim %8885, %c0 : tensor + %8926 = arith.index_cast %dim_3123 : index to i64 + %dim_3124 = tensor.dim %8921, %c0 : tensor + %8927 = arith.index_cast %dim_3124 : index to i64 + %8928 = arith.maxsi %8926, %8927 : i64 + %8929 = arith.index_cast %8928 : i64 to index + %from_elements_3125 = tensor.from_elements %8929, %c4096 : tensor<2xindex> + %8930 = stablehlo.dynamic_broadcast_in_dim %8885, %from_elements_3125, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_3126 = tensor.dim %8930, %c0 : tensor + %8931 = arith.index_cast %dim_3126 : index to i64 + %from_elements_3127 = tensor.from_elements %8931, %c4096_i64 : tensor<2xi64> + %8932 = stablehlo.real_dynamic_slice %8925, %c_22, %from_elements_3127, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_3128 = tensor.from_elements %8931, %c4096_i64, %c1_i64 : tensor<3xi64> + %8933 = stablehlo.dynamic_reshape %8930, %from_elements_3128 : (tensor, tensor<3xi64>) -> tensor + %8934 = stablehlo.dynamic_iota %from_elements_3128, dim = 1 : (tensor<3xi64>) -> tensor + %8935 = stablehlo.concatenate %8933, %8934, dim = 2 : (tensor, tensor) -> tensor + %8936 = "stablehlo.scatter"(%8873, %8935, %8932) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, 
tensor, tensor) -> tensor<3x4096xf32> + %8937 = stablehlo.slice %8745 [3:4, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %8938 = stablehlo.reshape %8937 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %8939 = stablehlo.custom_call @byteir.non_zero(%8938) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_3129 = tensor.dim %8939, %c0 : tensor + %8940 = arith.index_cast %dim_3129 : index to i64 + %from_elements_3130 = tensor.from_elements %8940, %c1_i64 : tensor<2xi64> + %8941 = stablehlo.real_dynamic_slice %8939, %c_22, %from_elements_3130, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3131 = tensor.dim %8941, %c0 : tensor + %8942 = arith.index_cast %dim_3131 : index to i64 + %from_elements_3132 = tensor.from_elements %8942 : tensor<1xi64> + %8943 = stablehlo.dynamic_reshape %8941, %from_elements_3132 : (tensor, tensor<1xi64>) -> tensor + %from_elements_3133 = tensor.from_elements %8940, %c2_i64 : tensor<2xi64> + %8944 = stablehlo.real_dynamic_slice %8939, %c_24, %from_elements_3133, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3134 = tensor.dim %8944, %c0 : tensor + %8945 = arith.index_cast %dim_3134 : index to i64 + %from_elements_3135 = tensor.from_elements %8945 : tensor<1xi64> + %8946 = stablehlo.dynamic_reshape %8944, %from_elements_3135 : (tensor, tensor<1xi64>) -> tensor + %dim_3136 = tensor.dim %8946, %c0 : tensor + %8947 = arith.index_cast %dim_3136 : index to i64 + %from_elements_3137 = tensor.from_elements %8947, %c1_i64 : tensor<2xi64> + %8948 = stablehlo.dynamic_reshape %8946, %from_elements_3137 : (tensor, tensor<2xi64>) -> tensor + %dim_3138 = tensor.dim %8948, %c0 : tensor + %8949 = arith.index_cast %dim_3138 : index to i64 + %from_elements_3139 = tensor.from_elements %c1_i64, %8949, %c4096_i64 : tensor<3xi64> + %8950 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_3139, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3140 = tensor.dim %8950, %c1 : tensor<1x?x4096xi64> + %8951 = arith.index_cast %dim_3140 : index to i64 + %from_elements_3141 = tensor.from_elements %c1_i64, %8951, %c4096_i64, %c1_i64 : tensor<4xi64> + %8952 = stablehlo.dynamic_reshape %8950, %from_elements_3141 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %8953 = stablehlo.dynamic_broadcast_in_dim %8948, %from_elements_3139, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3142 = tensor.dim %8953, %c1 : tensor<1x?x4096xi64> + %8954 = arith.index_cast %dim_3142 : index to i64 + %from_elements_3143 = tensor.from_elements %c1_i64, %8954, %c4096_i64, %c1_i64 : tensor<4xi64> + %8955 = stablehlo.dynamic_reshape %8953, %from_elements_3143 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %8956 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_3139, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3144 = tensor.dim %8956, %c1 : tensor<1x?x4096xi64> + %8957 = arith.index_cast %dim_3144 : index to i64 + %from_elements_3145 = tensor.from_elements %c1_i64, %8957, %c4096_i64, %c1_i64 : tensor<4xi64> + %8958 = stablehlo.dynamic_reshape %8956, %from_elements_3145 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %8959 = stablehlo.concatenate %8952, %8955, %8958, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %8960 = "stablehlo.gather"(%8756, %8959) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, 
slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %8961 = shape.shape_of %8960 : tensor<1x?x4096xf32> -> tensor<3xindex> + %8962 = shape.num_elements %8961 : tensor<3xindex> -> index + %8963 = stablehlo.compute_reshape_shape %8962, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %8964 = stablehlo.dynamic_reshape %8960, %8963 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %8965 = stablehlo.dot %8964, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %8966 = stablehlo.logistic %8965 : tensor + %8967 = shape.shape_of %8966 : tensor -> tensor<2xindex> + %8968 = shape.shape_of %8965 : tensor -> tensor<2xindex> + %8969 = shape.cstr_broadcastable %8967, %8968 : tensor<2xindex>, tensor<2xindex> + %8970 = shape.assuming %8969 -> (tensor) { + %19688 = shape.broadcast %8967, %8968 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %8966, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %8965, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %8971 = shape.shape_of %8970 : tensor -> tensor<2xindex> + %8972 = shape.cstr_broadcastable %8971, %8968 : tensor<2xindex>, tensor<2xindex> + %8973 = shape.assuming %8972 -> (tensor) { + %19688 = shape.broadcast %8971, %8968 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %8970, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %8965, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %8974 = stablehlo.dot %8973, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_3146 = tensor.dim %8946, %c0 : tensor + %8975 = arith.index_cast %dim_3146 : index to i64 + %from_elements_3147 = tensor.from_elements %8975, %c1_i64 : tensor<2xi64> + %8976 = stablehlo.dynamic_reshape %8946, %from_elements_3147 : (tensor, tensor<2xi64>) -> tensor + %dim_3148 = tensor.dim %8943, %c0 : tensor + %8977 = arith.index_cast %dim_3148 : index to i64 + %from_elements_3149 = tensor.from_elements %8977, %c1_i64 : tensor<2xi64> + %8978 = stablehlo.dynamic_reshape %8943, %from_elements_3149 : (tensor, tensor<2xi64>) -> tensor + %8979 = stablehlo.concatenate %8976, %8978, dim = 1 : (tensor, tensor) -> tensor + %8980 = "stablehlo.gather"(%8785, %8979) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %8981 = shape.shape_of %8974 : tensor -> tensor<2xindex> + %8982 = shape.shape_of %8980 : tensor -> tensor<2xindex> + %8983 = shape.cstr_broadcastable %8981, %8982 : tensor<2xindex>, tensor<2xindex> + %8984 = shape.assuming %8983 -> (tensor) { + %19688 = shape.broadcast %8981, %8982 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %8974, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %8980, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %8985 = shape.shape_of %8984 : tensor -> tensor<2xindex> + %8986 = stablehlo.dynamic_broadcast_in_dim %8984, %8985, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %8987 = 
stablehlo.dynamic_broadcast_in_dim %213, %8985, dims = [] : (tensor, tensor<2xindex>) -> tensor + %8988 = stablehlo.multiply %8986, %8987 : tensor + %dim_3150 = tensor.dim %8948, %c0 : tensor + %8989 = arith.index_cast %dim_3150 : index to i64 + %dim_3151 = tensor.dim %8984, %c0 : tensor + %8990 = arith.index_cast %dim_3151 : index to i64 + %8991 = arith.maxsi %8989, %8990 : i64 + %8992 = arith.index_cast %8991 : i64 to index + %from_elements_3152 = tensor.from_elements %8992, %c4096 : tensor<2xindex> + %8993 = stablehlo.dynamic_broadcast_in_dim %8948, %from_elements_3152, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_3153 = tensor.dim %8993, %c0 : tensor + %8994 = arith.index_cast %dim_3153 : index to i64 + %from_elements_3154 = tensor.from_elements %8994, %c4096_i64 : tensor<2xi64> + %8995 = stablehlo.real_dynamic_slice %8988, %c_22, %from_elements_3154, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_3155 = tensor.from_elements %8994, %c4096_i64, %c1_i64 : tensor<3xi64> + %8996 = stablehlo.dynamic_reshape %8993, %from_elements_3155 : (tensor, tensor<3xi64>) -> tensor + %8997 = stablehlo.dynamic_iota %from_elements_3155, dim = 1 : (tensor<3xi64>) -> tensor + %8998 = stablehlo.concatenate %8996, %8997, dim = 2 : (tensor, tensor) -> tensor + %8999 = "stablehlo.scatter"(%8936, %8998, %8995) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %9000 = stablehlo.slice %8745 [4:5, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %9001 = stablehlo.reshape %9000 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %9002 = stablehlo.custom_call @byteir.non_zero(%9001) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_3156 = tensor.dim %9002, %c0 : tensor + %9003 = arith.index_cast %dim_3156 : index to i64 + %from_elements_3157 = tensor.from_elements %9003, %c1_i64 : tensor<2xi64> + %9004 = stablehlo.real_dynamic_slice %9002, %c_22, %from_elements_3157, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3158 = tensor.dim %9004, %c0 : tensor + %9005 = arith.index_cast %dim_3158 : index to i64 + %from_elements_3159 = tensor.from_elements %9005 : tensor<1xi64> + %9006 = stablehlo.dynamic_reshape %9004, %from_elements_3159 : (tensor, tensor<1xi64>) -> tensor + %from_elements_3160 = tensor.from_elements %9003, %c2_i64 : tensor<2xi64> + %9007 = stablehlo.real_dynamic_slice %9002, %c_24, %from_elements_3160, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3161 = tensor.dim %9007, %c0 : tensor + %9008 = arith.index_cast %dim_3161 : index to i64 + %from_elements_3162 = tensor.from_elements %9008 : tensor<1xi64> + %9009 = stablehlo.dynamic_reshape %9007, %from_elements_3162 : (tensor, tensor<1xi64>) -> tensor + %dim_3163 = tensor.dim %9009, %c0 : tensor + %9010 = arith.index_cast %dim_3163 : index to i64 + %from_elements_3164 = tensor.from_elements %9010, %c1_i64 : tensor<2xi64> + %9011 = stablehlo.dynamic_reshape %9009, %from_elements_3164 : (tensor, tensor<2xi64>) -> tensor + %dim_3165 = tensor.dim %9011, %c0 : tensor + %9012 = arith.index_cast %dim_3165 : index to i64 + %from_elements_3166 = tensor.from_elements %c1_i64, %9012, %c4096_i64 : tensor<3xi64> + %9013 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_3166, dims = [0, 1, 2] : 
(tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3167 = tensor.dim %9013, %c1 : tensor<1x?x4096xi64> + %9014 = arith.index_cast %dim_3167 : index to i64 + %from_elements_3168 = tensor.from_elements %c1_i64, %9014, %c4096_i64, %c1_i64 : tensor<4xi64> + %9015 = stablehlo.dynamic_reshape %9013, %from_elements_3168 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %9016 = stablehlo.dynamic_broadcast_in_dim %9011, %from_elements_3166, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3169 = tensor.dim %9016, %c1 : tensor<1x?x4096xi64> + %9017 = arith.index_cast %dim_3169 : index to i64 + %from_elements_3170 = tensor.from_elements %c1_i64, %9017, %c4096_i64, %c1_i64 : tensor<4xi64> + %9018 = stablehlo.dynamic_reshape %9016, %from_elements_3170 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %9019 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_3166, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3171 = tensor.dim %9019, %c1 : tensor<1x?x4096xi64> + %9020 = arith.index_cast %dim_3171 : index to i64 + %from_elements_3172 = tensor.from_elements %c1_i64, %9020, %c4096_i64, %c1_i64 : tensor<4xi64> + %9021 = stablehlo.dynamic_reshape %9019, %from_elements_3172 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %9022 = stablehlo.concatenate %9015, %9018, %9021, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %9023 = "stablehlo.gather"(%8756, %9022) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %9024 = shape.shape_of %9023 : tensor<1x?x4096xf32> -> tensor<3xindex> + %9025 = shape.num_elements %9024 : tensor<3xindex> -> index + %9026 = stablehlo.compute_reshape_shape %9025, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %9027 = stablehlo.dynamic_reshape %9023, %9026 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %9028 = stablehlo.dot %9027, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %9029 = stablehlo.logistic %9028 : tensor + %9030 = shape.shape_of %9029 : tensor -> tensor<2xindex> + %9031 = shape.shape_of %9028 : tensor -> tensor<2xindex> + %9032 = shape.cstr_broadcastable %9030, %9031 : tensor<2xindex>, tensor<2xindex> + %9033 = shape.assuming %9032 -> (tensor) { + %19688 = shape.broadcast %9030, %9031 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %9029, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %9028, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %9034 = shape.shape_of %9033 : tensor -> tensor<2xindex> + %9035 = shape.cstr_broadcastable %9034, %9031 : tensor<2xindex>, tensor<2xindex> + %9036 = shape.assuming %9035 -> (tensor) { + %19688 = shape.broadcast %9034, %9031 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %9033, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %9028, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %9037 = stablehlo.dot %9036, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_3173 = tensor.dim %9009, 
%c0 : tensor + %9038 = arith.index_cast %dim_3173 : index to i64 + %from_elements_3174 = tensor.from_elements %9038, %c1_i64 : tensor<2xi64> + %9039 = stablehlo.dynamic_reshape %9009, %from_elements_3174 : (tensor, tensor<2xi64>) -> tensor + %dim_3175 = tensor.dim %9006, %c0 : tensor + %9040 = arith.index_cast %dim_3175 : index to i64 + %from_elements_3176 = tensor.from_elements %9040, %c1_i64 : tensor<2xi64> + %9041 = stablehlo.dynamic_reshape %9006, %from_elements_3176 : (tensor, tensor<2xi64>) -> tensor + %9042 = stablehlo.concatenate %9039, %9041, dim = 1 : (tensor, tensor) -> tensor + %9043 = "stablehlo.gather"(%8785, %9042) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %9044 = shape.shape_of %9037 : tensor -> tensor<2xindex> + %9045 = shape.shape_of %9043 : tensor -> tensor<2xindex> + %9046 = shape.cstr_broadcastable %9044, %9045 : tensor<2xindex>, tensor<2xindex> + %9047 = shape.assuming %9046 -> (tensor) { + %19688 = shape.broadcast %9044, %9045 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %9037, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %9043, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %9048 = shape.shape_of %9047 : tensor -> tensor<2xindex> + %9049 = stablehlo.dynamic_broadcast_in_dim %9047, %9048, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %9050 = stablehlo.dynamic_broadcast_in_dim %213, %9048, dims = [] : (tensor, tensor<2xindex>) -> tensor + %9051 = stablehlo.multiply %9049, %9050 : tensor + %dim_3177 = tensor.dim %9011, %c0 : tensor + %9052 = arith.index_cast %dim_3177 : index to i64 + %dim_3178 = tensor.dim %9047, %c0 : tensor + %9053 = arith.index_cast %dim_3178 : index to i64 + %9054 = arith.maxsi %9052, %9053 : i64 + %9055 = arith.index_cast %9054 : i64 to index + %from_elements_3179 = tensor.from_elements %9055, %c4096 : tensor<2xindex> + %9056 = stablehlo.dynamic_broadcast_in_dim %9011, %from_elements_3179, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_3180 = tensor.dim %9056, %c0 : tensor + %9057 = arith.index_cast %dim_3180 : index to i64 + %from_elements_3181 = tensor.from_elements %9057, %c4096_i64 : tensor<2xi64> + %9058 = stablehlo.real_dynamic_slice %9051, %c_22, %from_elements_3181, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_3182 = tensor.from_elements %9057, %c4096_i64, %c1_i64 : tensor<3xi64> + %9059 = stablehlo.dynamic_reshape %9056, %from_elements_3182 : (tensor, tensor<3xi64>) -> tensor + %9060 = stablehlo.dynamic_iota %from_elements_3182, dim = 1 : (tensor<3xi64>) -> tensor + %9061 = stablehlo.concatenate %9059, %9060, dim = 2 : (tensor, tensor) -> tensor + %9062 = "stablehlo.scatter"(%8999, %9061, %9058) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %9063 = stablehlo.slice %8745 [5:6, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %9064 = stablehlo.reshape %9063 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %9065 = stablehlo.custom_call @byteir.non_zero(%9064) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + 
%dim_3183 = tensor.dim %9065, %c0 : tensor + %9066 = arith.index_cast %dim_3183 : index to i64 + %from_elements_3184 = tensor.from_elements %9066, %c1_i64 : tensor<2xi64> + %9067 = stablehlo.real_dynamic_slice %9065, %c_22, %from_elements_3184, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3185 = tensor.dim %9067, %c0 : tensor + %9068 = arith.index_cast %dim_3185 : index to i64 + %from_elements_3186 = tensor.from_elements %9068 : tensor<1xi64> + %9069 = stablehlo.dynamic_reshape %9067, %from_elements_3186 : (tensor, tensor<1xi64>) -> tensor + %from_elements_3187 = tensor.from_elements %9066, %c2_i64 : tensor<2xi64> + %9070 = stablehlo.real_dynamic_slice %9065, %c_24, %from_elements_3187, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3188 = tensor.dim %9070, %c0 : tensor + %9071 = arith.index_cast %dim_3188 : index to i64 + %from_elements_3189 = tensor.from_elements %9071 : tensor<1xi64> + %9072 = stablehlo.dynamic_reshape %9070, %from_elements_3189 : (tensor, tensor<1xi64>) -> tensor + %dim_3190 = tensor.dim %9072, %c0 : tensor + %9073 = arith.index_cast %dim_3190 : index to i64 + %from_elements_3191 = tensor.from_elements %9073, %c1_i64 : tensor<2xi64> + %9074 = stablehlo.dynamic_reshape %9072, %from_elements_3191 : (tensor, tensor<2xi64>) -> tensor + %dim_3192 = tensor.dim %9074, %c0 : tensor + %9075 = arith.index_cast %dim_3192 : index to i64 + %from_elements_3193 = tensor.from_elements %c1_i64, %9075, %c4096_i64 : tensor<3xi64> + %9076 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_3193, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3194 = tensor.dim %9076, %c1 : tensor<1x?x4096xi64> + %9077 = arith.index_cast %dim_3194 : index to i64 + %from_elements_3195 = tensor.from_elements %c1_i64, %9077, %c4096_i64, %c1_i64 : tensor<4xi64> + %9078 = stablehlo.dynamic_reshape %9076, %from_elements_3195 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %9079 = stablehlo.dynamic_broadcast_in_dim %9074, %from_elements_3193, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3196 = tensor.dim %9079, %c1 : tensor<1x?x4096xi64> + %9080 = arith.index_cast %dim_3196 : index to i64 + %from_elements_3197 = tensor.from_elements %c1_i64, %9080, %c4096_i64, %c1_i64 : tensor<4xi64> + %9081 = stablehlo.dynamic_reshape %9079, %from_elements_3197 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %9082 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_3193, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3198 = tensor.dim %9082, %c1 : tensor<1x?x4096xi64> + %9083 = arith.index_cast %dim_3198 : index to i64 + %from_elements_3199 = tensor.from_elements %c1_i64, %9083, %c4096_i64, %c1_i64 : tensor<4xi64> + %9084 = stablehlo.dynamic_reshape %9082, %from_elements_3199 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %9085 = stablehlo.concatenate %9078, %9081, %9084, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %9086 = "stablehlo.gather"(%8756, %9085) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %9087 = shape.shape_of %9086 : tensor<1x?x4096xf32> -> tensor<3xindex> + %9088 = shape.num_elements %9087 : tensor<3xindex> -> index + %9089 = stablehlo.compute_reshape_shape %9088, %c_23 : (index, tensor<2xi64>) -> 
tensor<2xi64> + %9090 = stablehlo.dynamic_reshape %9086, %9089 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %9091 = stablehlo.dot %9090, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %9092 = stablehlo.logistic %9091 : tensor + %9093 = shape.shape_of %9092 : tensor -> tensor<2xindex> + %9094 = shape.shape_of %9091 : tensor -> tensor<2xindex> + %9095 = shape.cstr_broadcastable %9093, %9094 : tensor<2xindex>, tensor<2xindex> + %9096 = shape.assuming %9095 -> (tensor) { + %19688 = shape.broadcast %9093, %9094 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %9092, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %9091, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %9097 = shape.shape_of %9096 : tensor -> tensor<2xindex> + %9098 = shape.cstr_broadcastable %9097, %9094 : tensor<2xindex>, tensor<2xindex> + %9099 = shape.assuming %9098 -> (tensor) { + %19688 = shape.broadcast %9097, %9094 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %9096, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %9091, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %9100 = stablehlo.dot %9099, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_3200 = tensor.dim %9072, %c0 : tensor + %9101 = arith.index_cast %dim_3200 : index to i64 + %from_elements_3201 = tensor.from_elements %9101, %c1_i64 : tensor<2xi64> + %9102 = stablehlo.dynamic_reshape %9072, %from_elements_3201 : (tensor, tensor<2xi64>) -> tensor + %dim_3202 = tensor.dim %9069, %c0 : tensor + %9103 = arith.index_cast %dim_3202 : index to i64 + %from_elements_3203 = tensor.from_elements %9103, %c1_i64 : tensor<2xi64> + %9104 = stablehlo.dynamic_reshape %9069, %from_elements_3203 : (tensor, tensor<2xi64>) -> tensor + %9105 = stablehlo.concatenate %9102, %9104, dim = 1 : (tensor, tensor) -> tensor + %9106 = "stablehlo.gather"(%8785, %9105) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %9107 = shape.shape_of %9100 : tensor -> tensor<2xindex> + %9108 = shape.shape_of %9106 : tensor -> tensor<2xindex> + %9109 = shape.cstr_broadcastable %9107, %9108 : tensor<2xindex>, tensor<2xindex> + %9110 = shape.assuming %9109 -> (tensor) { + %19688 = shape.broadcast %9107, %9108 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %9100, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %9106, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %9111 = shape.shape_of %9110 : tensor -> tensor<2xindex> + %9112 = stablehlo.dynamic_broadcast_in_dim %9110, %9111, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %9113 = stablehlo.dynamic_broadcast_in_dim %213, %9111, dims = [] : (tensor, tensor<2xindex>) -> tensor + %9114 = stablehlo.multiply %9112, %9113 : tensor + %dim_3204 = tensor.dim %9074, %c0 : tensor + %9115 = arith.index_cast %dim_3204 : index to i64 + %dim_3205 = tensor.dim %9110, %c0 : tensor + %9116 = arith.index_cast 
%dim_3205 : index to i64 + %9117 = arith.maxsi %9115, %9116 : i64 + %9118 = arith.index_cast %9117 : i64 to index + %from_elements_3206 = tensor.from_elements %9118, %c4096 : tensor<2xindex> + %9119 = stablehlo.dynamic_broadcast_in_dim %9074, %from_elements_3206, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_3207 = tensor.dim %9119, %c0 : tensor + %9120 = arith.index_cast %dim_3207 : index to i64 + %from_elements_3208 = tensor.from_elements %9120, %c4096_i64 : tensor<2xi64> + %9121 = stablehlo.real_dynamic_slice %9114, %c_22, %from_elements_3208, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_3209 = tensor.from_elements %9120, %c4096_i64, %c1_i64 : tensor<3xi64> + %9122 = stablehlo.dynamic_reshape %9119, %from_elements_3209 : (tensor, tensor<3xi64>) -> tensor + %9123 = stablehlo.dynamic_iota %from_elements_3209, dim = 1 : (tensor<3xi64>) -> tensor + %9124 = stablehlo.concatenate %9122, %9123, dim = 2 : (tensor, tensor) -> tensor + %9125 = "stablehlo.scatter"(%9062, %9124, %9121) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %9126 = stablehlo.slice %8745 [6:7, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %9127 = stablehlo.reshape %9126 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %9128 = stablehlo.custom_call @byteir.non_zero(%9127) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_3210 = tensor.dim %9128, %c0 : tensor + %9129 = arith.index_cast %dim_3210 : index to i64 + %from_elements_3211 = tensor.from_elements %9129, %c1_i64 : tensor<2xi64> + %9130 = stablehlo.real_dynamic_slice %9128, %c_22, %from_elements_3211, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3212 = tensor.dim %9130, %c0 : tensor + %9131 = arith.index_cast %dim_3212 : index to i64 + %from_elements_3213 = tensor.from_elements %9131 : tensor<1xi64> + %9132 = stablehlo.dynamic_reshape %9130, %from_elements_3213 : (tensor, tensor<1xi64>) -> tensor + %from_elements_3214 = tensor.from_elements %9129, %c2_i64 : tensor<2xi64> + %9133 = stablehlo.real_dynamic_slice %9128, %c_24, %from_elements_3214, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3215 = tensor.dim %9133, %c0 : tensor + %9134 = arith.index_cast %dim_3215 : index to i64 + %from_elements_3216 = tensor.from_elements %9134 : tensor<1xi64> + %9135 = stablehlo.dynamic_reshape %9133, %from_elements_3216 : (tensor, tensor<1xi64>) -> tensor + %dim_3217 = tensor.dim %9135, %c0 : tensor + %9136 = arith.index_cast %dim_3217 : index to i64 + %from_elements_3218 = tensor.from_elements %9136, %c1_i64 : tensor<2xi64> + %9137 = stablehlo.dynamic_reshape %9135, %from_elements_3218 : (tensor, tensor<2xi64>) -> tensor + %dim_3219 = tensor.dim %9137, %c0 : tensor + %9138 = arith.index_cast %dim_3219 : index to i64 + %from_elements_3220 = tensor.from_elements %c1_i64, %9138, %c4096_i64 : tensor<3xi64> + %9139 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_3220, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3221 = tensor.dim %9139, %c1 : tensor<1x?x4096xi64> + %9140 = arith.index_cast %dim_3221 : index to i64 + %from_elements_3222 = tensor.from_elements %c1_i64, %9140, %c4096_i64, %c1_i64 : tensor<4xi64> + %9141 = stablehlo.dynamic_reshape %9139, 
%from_elements_3222 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %9142 = stablehlo.dynamic_broadcast_in_dim %9137, %from_elements_3220, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3223 = tensor.dim %9142, %c1 : tensor<1x?x4096xi64> + %9143 = arith.index_cast %dim_3223 : index to i64 + %from_elements_3224 = tensor.from_elements %c1_i64, %9143, %c4096_i64, %c1_i64 : tensor<4xi64> + %9144 = stablehlo.dynamic_reshape %9142, %from_elements_3224 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %9145 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_3220, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3225 = tensor.dim %9145, %c1 : tensor<1x?x4096xi64> + %9146 = arith.index_cast %dim_3225 : index to i64 + %from_elements_3226 = tensor.from_elements %c1_i64, %9146, %c4096_i64, %c1_i64 : tensor<4xi64> + %9147 = stablehlo.dynamic_reshape %9145, %from_elements_3226 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %9148 = stablehlo.concatenate %9141, %9144, %9147, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %9149 = "stablehlo.gather"(%8756, %9148) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %9150 = shape.shape_of %9149 : tensor<1x?x4096xf32> -> tensor<3xindex> + %9151 = shape.num_elements %9150 : tensor<3xindex> -> index + %9152 = stablehlo.compute_reshape_shape %9151, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %9153 = stablehlo.dynamic_reshape %9149, %9152 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %9154 = stablehlo.dot %9153, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %9155 = stablehlo.logistic %9154 : tensor + %9156 = shape.shape_of %9155 : tensor -> tensor<2xindex> + %9157 = shape.shape_of %9154 : tensor -> tensor<2xindex> + %9158 = shape.cstr_broadcastable %9156, %9157 : tensor<2xindex>, tensor<2xindex> + %9159 = shape.assuming %9158 -> (tensor) { + %19688 = shape.broadcast %9156, %9157 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %9155, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %9154, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %9160 = shape.shape_of %9159 : tensor -> tensor<2xindex> + %9161 = shape.cstr_broadcastable %9160, %9157 : tensor<2xindex>, tensor<2xindex> + %9162 = shape.assuming %9161 -> (tensor) { + %19688 = shape.broadcast %9160, %9157 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %9159, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %9154, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %9163 = stablehlo.dot %9162, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_3227 = tensor.dim %9135, %c0 : tensor + %9164 = arith.index_cast %dim_3227 : index to i64 + %from_elements_3228 = tensor.from_elements %9164, %c1_i64 : tensor<2xi64> + %9165 = stablehlo.dynamic_reshape %9135, %from_elements_3228 : (tensor, tensor<2xi64>) -> tensor + %dim_3229 = tensor.dim %9132, %c0 : tensor + %9166 = 
arith.index_cast %dim_3229 : index to i64 + %from_elements_3230 = tensor.from_elements %9166, %c1_i64 : tensor<2xi64> + %9167 = stablehlo.dynamic_reshape %9132, %from_elements_3230 : (tensor, tensor<2xi64>) -> tensor + %9168 = stablehlo.concatenate %9165, %9167, dim = 1 : (tensor, tensor) -> tensor + %9169 = "stablehlo.gather"(%8785, %9168) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %9170 = shape.shape_of %9163 : tensor -> tensor<2xindex> + %9171 = shape.shape_of %9169 : tensor -> tensor<2xindex> + %9172 = shape.cstr_broadcastable %9170, %9171 : tensor<2xindex>, tensor<2xindex> + %9173 = shape.assuming %9172 -> (tensor) { + %19688 = shape.broadcast %9170, %9171 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %9163, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %9169, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %9174 = shape.shape_of %9173 : tensor -> tensor<2xindex> + %9175 = stablehlo.dynamic_broadcast_in_dim %9173, %9174, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %9176 = stablehlo.dynamic_broadcast_in_dim %213, %9174, dims = [] : (tensor, tensor<2xindex>) -> tensor + %9177 = stablehlo.multiply %9175, %9176 : tensor + %dim_3231 = tensor.dim %9137, %c0 : tensor + %9178 = arith.index_cast %dim_3231 : index to i64 + %dim_3232 = tensor.dim %9173, %c0 : tensor + %9179 = arith.index_cast %dim_3232 : index to i64 + %9180 = arith.maxsi %9178, %9179 : i64 + %9181 = arith.index_cast %9180 : i64 to index + %from_elements_3233 = tensor.from_elements %9181, %c4096 : tensor<2xindex> + %9182 = stablehlo.dynamic_broadcast_in_dim %9137, %from_elements_3233, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_3234 = tensor.dim %9182, %c0 : tensor + %9183 = arith.index_cast %dim_3234 : index to i64 + %from_elements_3235 = tensor.from_elements %9183, %c4096_i64 : tensor<2xi64> + %9184 = stablehlo.real_dynamic_slice %9177, %c_22, %from_elements_3235, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_3236 = tensor.from_elements %9183, %c4096_i64, %c1_i64 : tensor<3xi64> + %9185 = stablehlo.dynamic_reshape %9182, %from_elements_3236 : (tensor, tensor<3xi64>) -> tensor + %9186 = stablehlo.dynamic_iota %from_elements_3236, dim = 1 : (tensor<3xi64>) -> tensor + %9187 = stablehlo.concatenate %9185, %9186, dim = 2 : (tensor, tensor) -> tensor + %9188 = "stablehlo.scatter"(%9125, %9187, %9184) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %9189 = stablehlo.slice %8745 [7:8, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %9190 = stablehlo.reshape %9189 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %9191 = stablehlo.custom_call @byteir.non_zero(%9190) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_3237 = tensor.dim %9191, %c0 : tensor + %9192 = arith.index_cast %dim_3237 : index to i64 + %from_elements_3238 = tensor.from_elements %9192, %c1_i64 : tensor<2xi64> + %9193 = stablehlo.real_dynamic_slice %9191, %c_22, %from_elements_3238, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, 
tensor<2xi64>) -> tensor + %dim_3239 = tensor.dim %9193, %c0 : tensor + %9194 = arith.index_cast %dim_3239 : index to i64 + %from_elements_3240 = tensor.from_elements %9194 : tensor<1xi64> + %9195 = stablehlo.dynamic_reshape %9193, %from_elements_3240 : (tensor, tensor<1xi64>) -> tensor + %from_elements_3241 = tensor.from_elements %9192, %c2_i64 : tensor<2xi64> + %9196 = stablehlo.real_dynamic_slice %9191, %c_24, %from_elements_3241, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3242 = tensor.dim %9196, %c0 : tensor + %9197 = arith.index_cast %dim_3242 : index to i64 + %from_elements_3243 = tensor.from_elements %9197 : tensor<1xi64> + %9198 = stablehlo.dynamic_reshape %9196, %from_elements_3243 : (tensor, tensor<1xi64>) -> tensor + %dim_3244 = tensor.dim %9198, %c0 : tensor + %9199 = arith.index_cast %dim_3244 : index to i64 + %from_elements_3245 = tensor.from_elements %9199, %c1_i64 : tensor<2xi64> + %9200 = stablehlo.dynamic_reshape %9198, %from_elements_3245 : (tensor, tensor<2xi64>) -> tensor + %dim_3246 = tensor.dim %9200, %c0 : tensor + %9201 = arith.index_cast %dim_3246 : index to i64 + %from_elements_3247 = tensor.from_elements %c1_i64, %9201, %c4096_i64 : tensor<3xi64> + %9202 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_3247, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3248 = tensor.dim %9202, %c1 : tensor<1x?x4096xi64> + %9203 = arith.index_cast %dim_3248 : index to i64 + %from_elements_3249 = tensor.from_elements %c1_i64, %9203, %c4096_i64, %c1_i64 : tensor<4xi64> + %9204 = stablehlo.dynamic_reshape %9202, %from_elements_3249 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %9205 = stablehlo.dynamic_broadcast_in_dim %9200, %from_elements_3247, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3250 = tensor.dim %9205, %c1 : tensor<1x?x4096xi64> + %9206 = arith.index_cast %dim_3250 : index to i64 + %from_elements_3251 = tensor.from_elements %c1_i64, %9206, %c4096_i64, %c1_i64 : tensor<4xi64> + %9207 = stablehlo.dynamic_reshape %9205, %from_elements_3251 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %9208 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_3247, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3252 = tensor.dim %9208, %c1 : tensor<1x?x4096xi64> + %9209 = arith.index_cast %dim_3252 : index to i64 + %from_elements_3253 = tensor.from_elements %c1_i64, %9209, %c4096_i64, %c1_i64 : tensor<4xi64> + %9210 = stablehlo.dynamic_reshape %9208, %from_elements_3253 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %9211 = stablehlo.concatenate %9204, %9207, %9210, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %9212 = "stablehlo.gather"(%8756, %9211) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %9213 = shape.shape_of %9212 : tensor<1x?x4096xf32> -> tensor<3xindex> + %9214 = shape.num_elements %9213 : tensor<3xindex> -> index + %9215 = stablehlo.compute_reshape_shape %9214, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %9216 = stablehlo.dynamic_reshape %9212, %9215 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %9217 = stablehlo.dot %9216, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %9218 = stablehlo.logistic %9217 : tensor + %9219 = shape.shape_of %9218 : tensor -> tensor<2xindex> 
+ %9220 = shape.shape_of %9217 : tensor -> tensor<2xindex> + %9221 = shape.cstr_broadcastable %9219, %9220 : tensor<2xindex>, tensor<2xindex> + %9222 = shape.assuming %9221 -> (tensor) { + %19688 = shape.broadcast %9219, %9220 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %9218, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %9217, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %9223 = shape.shape_of %9222 : tensor -> tensor<2xindex> + %9224 = shape.cstr_broadcastable %9223, %9220 : tensor<2xindex>, tensor<2xindex> + %9225 = shape.assuming %9224 -> (tensor) { + %19688 = shape.broadcast %9223, %9220 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %9222, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %9217, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %9226 = stablehlo.dot %9225, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_3254 = tensor.dim %9198, %c0 : tensor + %9227 = arith.index_cast %dim_3254 : index to i64 + %from_elements_3255 = tensor.from_elements %9227, %c1_i64 : tensor<2xi64> + %9228 = stablehlo.dynamic_reshape %9198, %from_elements_3255 : (tensor, tensor<2xi64>) -> tensor + %dim_3256 = tensor.dim %9195, %c0 : tensor + %9229 = arith.index_cast %dim_3256 : index to i64 + %from_elements_3257 = tensor.from_elements %9229, %c1_i64 : tensor<2xi64> + %9230 = stablehlo.dynamic_reshape %9195, %from_elements_3257 : (tensor, tensor<2xi64>) -> tensor + %9231 = stablehlo.concatenate %9228, %9230, dim = 1 : (tensor, tensor) -> tensor + %9232 = "stablehlo.gather"(%8785, %9231) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %9233 = shape.shape_of %9226 : tensor -> tensor<2xindex> + %9234 = shape.shape_of %9232 : tensor -> tensor<2xindex> + %9235 = shape.cstr_broadcastable %9233, %9234 : tensor<2xindex>, tensor<2xindex> + %9236 = shape.assuming %9235 -> (tensor) { + %19688 = shape.broadcast %9233, %9234 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %9226, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %9232, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %9237 = shape.shape_of %9236 : tensor -> tensor<2xindex> + %9238 = stablehlo.dynamic_broadcast_in_dim %9236, %9237, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %9239 = stablehlo.dynamic_broadcast_in_dim %213, %9237, dims = [] : (tensor, tensor<2xindex>) -> tensor + %9240 = stablehlo.multiply %9238, %9239 : tensor + %dim_3258 = tensor.dim %9200, %c0 : tensor + %9241 = arith.index_cast %dim_3258 : index to i64 + %dim_3259 = tensor.dim %9236, %c0 : tensor + %9242 = arith.index_cast %dim_3259 : index to i64 + %9243 = arith.maxsi %9241, %9242 : i64 + %9244 = arith.index_cast %9243 : i64 to index + %from_elements_3260 = tensor.from_elements %9244, %c4096 : tensor<2xindex> + %9245 = stablehlo.dynamic_broadcast_in_dim %9200, %from_elements_3260, dims = [0, 1] : (tensor, tensor<2xindex>) 
-> tensor + %dim_3261 = tensor.dim %9245, %c0 : tensor + %9246 = arith.index_cast %dim_3261 : index to i64 + %from_elements_3262 = tensor.from_elements %9246, %c4096_i64 : tensor<2xi64> + %9247 = stablehlo.real_dynamic_slice %9240, %c_22, %from_elements_3262, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_3263 = tensor.from_elements %9246, %c4096_i64, %c1_i64 : tensor<3xi64> + %9248 = stablehlo.dynamic_reshape %9245, %from_elements_3263 : (tensor, tensor<3xi64>) -> tensor + %9249 = stablehlo.dynamic_iota %from_elements_3263, dim = 1 : (tensor<3xi64>) -> tensor + %9250 = stablehlo.concatenate %9248, %9249, dim = 2 : (tensor, tensor) -> tensor + %9251 = "stablehlo.scatter"(%9188, %9250, %9247) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %9252 = stablehlo.reshape %9251 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %9253 = stablehlo.add %8718, %9252 : tensor<3x1x4096xf32> + %9254 = stablehlo.broadcast_in_dim %9253, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %9255 = stablehlo.power %9254, %15 : tensor<3x1x4096xf32> + %9256 = stablehlo.reduce(%9255 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %9257 = stablehlo.reshape %9256 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %9258 = stablehlo.broadcast_in_dim %9257, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %9259 = stablehlo.divide %9258, %21 : tensor<3x1x1xf32> + %9260 = stablehlo.broadcast_in_dim %9259, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %9261 = stablehlo.add %9260, %25 : tensor<3x1x1xf32> + %9262 = stablehlo.rsqrt %9261 : tensor<3x1x1xf32> + %9263 = stablehlo.broadcast_in_dim %9262, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %9264 = stablehlo.multiply %9254, %9263 : tensor<3x1x4096xf32> + %9265 = stablehlo.broadcast_in_dim %9264, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %9266 = stablehlo.multiply %9265, %31 : tensor<3x1x4096xf32> + %9267 = stablehlo.reshape %9266 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %9268 = stablehlo.dot %9267, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %9269 = stablehlo.reshape %9268 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %9270 = stablehlo.dot %9267, %37 : (tensor<3x4096xf32>, tensor<4096x1024xf32>) -> tensor<3x1024xf32> + %9271 = stablehlo.reshape %9270 : (tensor<3x1024xf32>) -> tensor<3x1x1024xf32> + %9272 = stablehlo.reshape %9269 : (tensor<3x1x4096xf32>) -> tensor<3x1x32x128xf32> + %9273 = stablehlo.transpose %9272, dims = [0, 2, 1, 3] : (tensor<3x1x32x128xf32>) -> tensor<3x32x1x128xf32> + %9274 = stablehlo.reshape %9271 : (tensor<3x1x1024xf32>) -> tensor<3x1x8x128xf32> + %9275 = stablehlo.transpose %9274, dims = [0, 2, 1, 3] : (tensor<3x1x8x128xf32>) -> tensor<3x8x1x128xf32> + %9276 = stablehlo.slice %arg30 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %9277 = stablehlo.slice %arg31 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %9278 = "stablehlo.gather"(%9276, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %9279 = stablehlo.reshape %9278 : 
(tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %9280 = "stablehlo.gather"(%9277, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %9281 = stablehlo.reshape %9280 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %9282 = stablehlo.broadcast_in_dim %9273, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %9283 = stablehlo.broadcast_in_dim %9279, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %9284 = stablehlo.multiply %9282, %9283 : tensor<3x32x1x128xf32> + %9285 = stablehlo.slice %9273 [0:3, 0:32, 0:1, 0:64] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %9286 = stablehlo.slice %9273 [0:3, 0:32, 0:1, 64:128] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %9287 = stablehlo.negate %9286 : tensor<3x32x1x64xf32> + %9288 = stablehlo.concatenate %9287, %9285, dim = 3 : (tensor<3x32x1x64xf32>, tensor<3x32x1x64xf32>) -> tensor<3x32x1x128xf32> + %9289 = stablehlo.broadcast_in_dim %9288, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %9290 = stablehlo.broadcast_in_dim %9281, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %9291 = stablehlo.multiply %9289, %9290 : tensor<3x32x1x128xf32> + %9292 = stablehlo.add %9284, %9291 : tensor<3x32x1x128xf32> + %9293 = stablehlo.broadcast_in_dim %9275, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %9294 = stablehlo.broadcast_in_dim %9279, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %9295 = stablehlo.multiply %9293, %9294 : tensor<3x8x1x128xf32> + %9296 = stablehlo.slice %9275 [0:3, 0:8, 0:1, 0:64] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %9297 = stablehlo.slice %9275 [0:3, 0:8, 0:1, 64:128] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %9298 = stablehlo.negate %9297 : tensor<3x8x1x64xf32> + %9299 = stablehlo.concatenate %9298, %9296, dim = 3 : (tensor<3x8x1x64xf32>, tensor<3x8x1x64xf32>) -> tensor<3x8x1x128xf32> + %9300 = stablehlo.broadcast_in_dim %9299, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %9301 = stablehlo.broadcast_in_dim %9281, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %9302 = stablehlo.multiply %9300, %9301 : tensor<3x8x1x128xf32> + %9303 = stablehlo.add %9295, %9302 : tensor<3x8x1x128xf32> + %9304 = stablehlo.concatenate %arg95, %9303, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %9305 = stablehlo.concatenate %arg96, %9275, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %9306 = stablehlo.reshape %9304 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %9307 = stablehlo.broadcast_in_dim %9306, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %9308 = stablehlo.reshape %9307 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %9309 = stablehlo.reshape %9305 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %9310 = stablehlo.broadcast_in_dim %9309, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %9311 = stablehlo.reshape %9310 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %9312 = stablehlo.transpose %9308, dims = [0, 1, 3, 2] : (tensor<3x32x8x128xf32>) -> tensor<3x32x128x8xf32> + %9313 = stablehlo.reshape %9292 : (tensor<3x32x1x128xf32>) -> tensor<96x1x128xf32> + %9314 = stablehlo.reshape %9312 : (tensor<3x32x128x8xf32>) -> 
tensor<96x128x8xf32> + %9315 = stablehlo.broadcast_in_dim %9314, dims = [0, 1, 2] : (tensor<96x128x8xf32>) -> tensor<96x128x8xf32> + %9316 = stablehlo.dot_general %9313, %9315, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x128xf32>, tensor<96x128x8xf32>) -> tensor<96x1x8xf32> + %9317 = stablehlo.reshape %9316 : (tensor<96x1x8xf32>) -> tensor<3x32x1x8xf32> + %9318 = stablehlo.broadcast_in_dim %9317, dims = [0, 1, 2, 3] : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %9319 = stablehlo.divide %9318, %89 : tensor<3x32x1x8xf32> + %9320 = stablehlo.custom_call @byteir.softmax(%9319) {byteir_attrs = {axis = 3 : i64}} : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %9321 = stablehlo.reshape %9320 : (tensor<3x32x1x8xf32>) -> tensor<96x1x8xf32> + %9322 = stablehlo.reshape %9311 : (tensor<3x32x8x128xf32>) -> tensor<96x8x128xf32> + %9323 = stablehlo.broadcast_in_dim %9322, dims = [0, 1, 2] : (tensor<96x8x128xf32>) -> tensor<96x8x128xf32> + %9324 = stablehlo.dot_general %9321, %9323, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x8xf32>, tensor<96x8x128xf32>) -> tensor<96x1x128xf32> + %9325 = stablehlo.reshape %9324 : (tensor<96x1x128xf32>) -> tensor<3x32x1x128xf32> + %9326 = stablehlo.transpose %9325, dims = [0, 2, 1, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x1x32x128xf32> + %9327 = stablehlo.reshape %9326 : (tensor<3x1x32x128xf32>) -> tensor<3x1x4096xf32> + %9328 = stablehlo.reshape %9327 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %9329 = stablehlo.dot %9328, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %9330 = stablehlo.reshape %9329 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %9331 = stablehlo.add %9253, %9330 : tensor<3x1x4096xf32> + %9332 = stablehlo.broadcast_in_dim %9331, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %9333 = stablehlo.power %9332, %15 : tensor<3x1x4096xf32> + %9334 = stablehlo.reduce(%9333 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %9335 = stablehlo.reshape %9334 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %9336 = stablehlo.broadcast_in_dim %9335, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %9337 = stablehlo.divide %9336, %21 : tensor<3x1x1xf32> + %9338 = stablehlo.broadcast_in_dim %9337, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %9339 = stablehlo.add %9338, %25 : tensor<3x1x1xf32> + %9340 = stablehlo.rsqrt %9339 : tensor<3x1x1xf32> + %9341 = stablehlo.broadcast_in_dim %9340, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %9342 = stablehlo.multiply %9332, %9341 : tensor<3x1x4096xf32> + %9343 = stablehlo.broadcast_in_dim %9342, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %9344 = stablehlo.multiply %9343, %31 : tensor<3x1x4096xf32> + %9345 = stablehlo.reshape %9344 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %9346 = stablehlo.dot %9345, %117 : (tensor<3x4096xf32>, tensor<4096x8xf32>) -> tensor<3x8xf32> + %9347 = stablehlo.custom_call @byteir.softmax(%9346) {byteir_attrs = {axis = 1 : i64}} : (tensor<3x8xf32>) -> tensor<3x8xf32> + %9348:2 = stablehlo.custom_call @byteir.top_k(%9347) {byteir_attrs = {axis = [1], k = 2 : i64, sorted = true}} : (tensor<3x8xf32>) -> (tensor<3x2xf32>, tensor<3x2xi64>) + %9349 = stablehlo.reduce(%9348#0 init: %cst_20) applies stablehlo.add across dimensions = [1] : (tensor<3x2xf32>, tensor) -> tensor<3xf32> + %9350 = stablehlo.reshape %9349 : (tensor<3xf32>) -> tensor<3x1xf32> + %9351 
= stablehlo.broadcast_in_dim %9348#0, dims = [0, 1] : (tensor<3x2xf32>) -> tensor<3x2xf32> + %9352 = stablehlo.broadcast_in_dim %9350, dims = [0, 1] : (tensor<3x1xf32>) -> tensor<3x2xf32> + %9353 = stablehlo.divide %9351, %9352 : tensor<3x2xf32> + %9354 = stablehlo.reshape %9348#1 : (tensor<3x2xi64>) -> tensor<3x2x1xi64> + %9355 = stablehlo.broadcast_in_dim %9354, dims = [0, 1, 2] : (tensor<3x2x1xi64>) -> tensor<3x2x8xi64> + %9356 = stablehlo.compare EQ, %9355, %137, SIGNED : (tensor<3x2x8xi64>, tensor<3x2x8xi64>) -> tensor<3x2x8xi1> + %9357 = stablehlo.convert %9356 : (tensor<3x2x8xi1>) -> tensor<3x2x8xi64> + %9358 = stablehlo.transpose %9357, dims = [2, 1, 0] : (tensor<3x2x8xi64>) -> tensor<8x2x3xi64> + %9359 = stablehlo.slice %9358 [0:1, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %9360 = stablehlo.reshape %9359 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %9361 = stablehlo.custom_call @byteir.non_zero(%9360) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_3264 = tensor.dim %9361, %c0 : tensor + %9362 = arith.index_cast %dim_3264 : index to i64 + %from_elements_3265 = tensor.from_elements %9362, %c1_i64 : tensor<2xi64> + %9363 = stablehlo.real_dynamic_slice %9361, %c_22, %from_elements_3265, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3266 = tensor.dim %9363, %c0 : tensor + %9364 = arith.index_cast %dim_3266 : index to i64 + %from_elements_3267 = tensor.from_elements %9364 : tensor<1xi64> + %9365 = stablehlo.dynamic_reshape %9363, %from_elements_3267 : (tensor, tensor<1xi64>) -> tensor + %from_elements_3268 = tensor.from_elements %9362, %c2_i64 : tensor<2xi64> + %9366 = stablehlo.real_dynamic_slice %9361, %c_24, %from_elements_3268, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3269 = tensor.dim %9366, %c0 : tensor + %9367 = arith.index_cast %dim_3269 : index to i64 + %from_elements_3270 = tensor.from_elements %9367 : tensor<1xi64> + %9368 = stablehlo.dynamic_reshape %9366, %from_elements_3270 : (tensor, tensor<1xi64>) -> tensor + %9369 = stablehlo.reshape %9345 : (tensor<3x4096xf32>) -> tensor<1x3x4096xf32> + %dim_3271 = tensor.dim %9368, %c0 : tensor + %9370 = arith.index_cast %dim_3271 : index to i64 + %from_elements_3272 = tensor.from_elements %9370, %c1_i64 : tensor<2xi64> + %9371 = stablehlo.dynamic_reshape %9368, %from_elements_3272 : (tensor, tensor<2xi64>) -> tensor + %dim_3273 = tensor.dim %9371, %c0 : tensor + %9372 = arith.index_cast %dim_3273 : index to i64 + %from_elements_3274 = tensor.from_elements %c1_i64, %9372, %c4096_i64 : tensor<3xi64> + %9373 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_3274, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3275 = tensor.dim %9373, %c1 : tensor<1x?x4096xi64> + %9374 = arith.index_cast %dim_3275 : index to i64 + %from_elements_3276 = tensor.from_elements %c1_i64, %9374, %c4096_i64, %c1_i64 : tensor<4xi64> + %9375 = stablehlo.dynamic_reshape %9373, %from_elements_3276 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %9376 = stablehlo.dynamic_broadcast_in_dim %9371, %from_elements_3274, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3277 = tensor.dim %9376, %c1 : tensor<1x?x4096xi64> + %9377 = arith.index_cast %dim_3277 : index to i64 + %from_elements_3278 = tensor.from_elements %c1_i64, %9377, %c4096_i64, %c1_i64 : tensor<4xi64> + %9378 = stablehlo.dynamic_reshape %9376, %from_elements_3278 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + 
%9379 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_3274, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3279 = tensor.dim %9379, %c1 : tensor<1x?x4096xi64> + %9380 = arith.index_cast %dim_3279 : index to i64 + %from_elements_3280 = tensor.from_elements %c1_i64, %9380, %c4096_i64, %c1_i64 : tensor<4xi64> + %9381 = stablehlo.dynamic_reshape %9379, %from_elements_3280 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %9382 = stablehlo.concatenate %9375, %9378, %9381, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %9383 = "stablehlo.gather"(%9369, %9382) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %9384 = shape.shape_of %9383 : tensor<1x?x4096xf32> -> tensor<3xindex> + %9385 = shape.num_elements %9384 : tensor<3xindex> -> index + %9386 = stablehlo.compute_reshape_shape %9385, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %9387 = stablehlo.dynamic_reshape %9383, %9386 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %9388 = stablehlo.dot %9387, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %9389 = stablehlo.logistic %9388 : tensor + %9390 = shape.shape_of %9389 : tensor -> tensor<2xindex> + %9391 = shape.shape_of %9388 : tensor -> tensor<2xindex> + %9392 = shape.cstr_broadcastable %9390, %9391 : tensor<2xindex>, tensor<2xindex> + %9393 = shape.assuming %9392 -> (tensor) { + %19688 = shape.broadcast %9390, %9391 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %9389, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %9388, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %9394 = shape.shape_of %9393 : tensor -> tensor<2xindex> + %9395 = shape.cstr_broadcastable %9394, %9391 : tensor<2xindex>, tensor<2xindex> + %9396 = shape.assuming %9395 -> (tensor) { + %19688 = shape.broadcast %9394, %9391 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %9393, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %9388, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %9397 = stablehlo.dot %9396, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %9398 = stablehlo.reshape %9353 : (tensor<3x2xf32>) -> tensor<3x2x1xf32> + %dim_3281 = tensor.dim %9368, %c0 : tensor + %9399 = arith.index_cast %dim_3281 : index to i64 + %from_elements_3282 = tensor.from_elements %9399, %c1_i64 : tensor<2xi64> + %9400 = stablehlo.dynamic_reshape %9368, %from_elements_3282 : (tensor, tensor<2xi64>) -> tensor + %dim_3283 = tensor.dim %9365, %c0 : tensor + %9401 = arith.index_cast %dim_3283 : index to i64 + %from_elements_3284 = tensor.from_elements %9401, %c1_i64 : tensor<2xi64> + %9402 = stablehlo.dynamic_reshape %9365, %from_elements_3284 : (tensor, tensor<2xi64>) -> tensor + %9403 = stablehlo.concatenate %9400, %9402, dim = 1 : (tensor, tensor) -> tensor + %9404 = "stablehlo.gather"(%9398, %9403) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %9405 = shape.shape_of 
%9397 : tensor -> tensor<2xindex> + %9406 = shape.shape_of %9404 : tensor -> tensor<2xindex> + %9407 = shape.cstr_broadcastable %9405, %9406 : tensor<2xindex>, tensor<2xindex> + %9408 = shape.assuming %9407 -> (tensor) { + %19688 = shape.broadcast %9405, %9406 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %9397, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %9404, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %9409 = shape.shape_of %9408 : tensor -> tensor<2xindex> + %9410 = stablehlo.dynamic_broadcast_in_dim %9408, %9409, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %9411 = stablehlo.dynamic_broadcast_in_dim %213, %9409, dims = [] : (tensor, tensor<2xindex>) -> tensor + %9412 = stablehlo.multiply %9410, %9411 : tensor + %dim_3285 = tensor.dim %9371, %c0 : tensor + %9413 = arith.index_cast %dim_3285 : index to i64 + %dim_3286 = tensor.dim %9408, %c0 : tensor + %9414 = arith.index_cast %dim_3286 : index to i64 + %9415 = arith.maxsi %9413, %9414 : i64 + %9416 = arith.index_cast %9415 : i64 to index + %from_elements_3287 = tensor.from_elements %9416, %c4096 : tensor<2xindex> + %9417 = stablehlo.dynamic_broadcast_in_dim %9371, %from_elements_3287, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_3288 = tensor.dim %9417, %c0 : tensor + %9418 = arith.index_cast %dim_3288 : index to i64 + %from_elements_3289 = tensor.from_elements %9418, %c4096_i64 : tensor<2xi64> + %9419 = stablehlo.real_dynamic_slice %9412, %c_22, %from_elements_3289, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_3290 = tensor.from_elements %9418, %c4096_i64, %c1_i64 : tensor<3xi64> + %9420 = stablehlo.dynamic_reshape %9417, %from_elements_3290 : (tensor, tensor<3xi64>) -> tensor + %9421 = stablehlo.dynamic_iota %from_elements_3290, dim = 1 : (tensor<3xi64>) -> tensor + %9422 = stablehlo.concatenate %9420, %9421, dim = 2 : (tensor, tensor) -> tensor + %9423 = "stablehlo.scatter"(%cst_2, %9422, %9419) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %9424 = stablehlo.slice %9358 [1:2, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %9425 = stablehlo.reshape %9424 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %9426 = stablehlo.custom_call @byteir.non_zero(%9425) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_3291 = tensor.dim %9426, %c0 : tensor + %9427 = arith.index_cast %dim_3291 : index to i64 + %from_elements_3292 = tensor.from_elements %9427, %c1_i64 : tensor<2xi64> + %9428 = stablehlo.real_dynamic_slice %9426, %c_22, %from_elements_3292, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3293 = tensor.dim %9428, %c0 : tensor + %9429 = arith.index_cast %dim_3293 : index to i64 + %from_elements_3294 = tensor.from_elements %9429 : tensor<1xi64> + %9430 = stablehlo.dynamic_reshape %9428, %from_elements_3294 : (tensor, tensor<1xi64>) -> tensor + %from_elements_3295 = tensor.from_elements %9427, %c2_i64 : tensor<2xi64> + %9431 = stablehlo.real_dynamic_slice %9426, %c_24, %from_elements_3295, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> 
tensor + %dim_3296 = tensor.dim %9431, %c0 : tensor + %9432 = arith.index_cast %dim_3296 : index to i64 + %from_elements_3297 = tensor.from_elements %9432 : tensor<1xi64> + %9433 = stablehlo.dynamic_reshape %9431, %from_elements_3297 : (tensor, tensor<1xi64>) -> tensor + %dim_3298 = tensor.dim %9433, %c0 : tensor + %9434 = arith.index_cast %dim_3298 : index to i64 + %from_elements_3299 = tensor.from_elements %9434, %c1_i64 : tensor<2xi64> + %9435 = stablehlo.dynamic_reshape %9433, %from_elements_3299 : (tensor, tensor<2xi64>) -> tensor + %dim_3300 = tensor.dim %9435, %c0 : tensor + %9436 = arith.index_cast %dim_3300 : index to i64 + %from_elements_3301 = tensor.from_elements %c1_i64, %9436, %c4096_i64 : tensor<3xi64> + %9437 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_3301, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3302 = tensor.dim %9437, %c1 : tensor<1x?x4096xi64> + %9438 = arith.index_cast %dim_3302 : index to i64 + %from_elements_3303 = tensor.from_elements %c1_i64, %9438, %c4096_i64, %c1_i64 : tensor<4xi64> + %9439 = stablehlo.dynamic_reshape %9437, %from_elements_3303 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %9440 = stablehlo.dynamic_broadcast_in_dim %9435, %from_elements_3301, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3304 = tensor.dim %9440, %c1 : tensor<1x?x4096xi64> + %9441 = arith.index_cast %dim_3304 : index to i64 + %from_elements_3305 = tensor.from_elements %c1_i64, %9441, %c4096_i64, %c1_i64 : tensor<4xi64> + %9442 = stablehlo.dynamic_reshape %9440, %from_elements_3305 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %9443 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_3301, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3306 = tensor.dim %9443, %c1 : tensor<1x?x4096xi64> + %9444 = arith.index_cast %dim_3306 : index to i64 + %from_elements_3307 = tensor.from_elements %c1_i64, %9444, %c4096_i64, %c1_i64 : tensor<4xi64> + %9445 = stablehlo.dynamic_reshape %9443, %from_elements_3307 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %9446 = stablehlo.concatenate %9439, %9442, %9445, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %9447 = "stablehlo.gather"(%9369, %9446) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %9448 = shape.shape_of %9447 : tensor<1x?x4096xf32> -> tensor<3xindex> + %9449 = shape.num_elements %9448 : tensor<3xindex> -> index + %9450 = stablehlo.compute_reshape_shape %9449, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %9451 = stablehlo.dynamic_reshape %9447, %9450 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %9452 = stablehlo.dot %9451, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %9453 = stablehlo.logistic %9452 : tensor + %9454 = shape.shape_of %9453 : tensor -> tensor<2xindex> + %9455 = shape.shape_of %9452 : tensor -> tensor<2xindex> + %9456 = shape.cstr_broadcastable %9454, %9455 : tensor<2xindex>, tensor<2xindex> + %9457 = shape.assuming %9456 -> (tensor) { + %19688 = shape.broadcast %9454, %9455 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %9453, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %9452, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> 
tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %9458 = shape.shape_of %9457 : tensor -> tensor<2xindex> + %9459 = shape.cstr_broadcastable %9458, %9455 : tensor<2xindex>, tensor<2xindex> + %9460 = shape.assuming %9459 -> (tensor) { + %19688 = shape.broadcast %9458, %9455 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %9457, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %9452, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %9461 = stablehlo.dot %9460, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_3308 = tensor.dim %9433, %c0 : tensor + %9462 = arith.index_cast %dim_3308 : index to i64 + %from_elements_3309 = tensor.from_elements %9462, %c1_i64 : tensor<2xi64> + %9463 = stablehlo.dynamic_reshape %9433, %from_elements_3309 : (tensor, tensor<2xi64>) -> tensor + %dim_3310 = tensor.dim %9430, %c0 : tensor + %9464 = arith.index_cast %dim_3310 : index to i64 + %from_elements_3311 = tensor.from_elements %9464, %c1_i64 : tensor<2xi64> + %9465 = stablehlo.dynamic_reshape %9430, %from_elements_3311 : (tensor, tensor<2xi64>) -> tensor + %9466 = stablehlo.concatenate %9463, %9465, dim = 1 : (tensor, tensor) -> tensor + %9467 = "stablehlo.gather"(%9398, %9466) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %9468 = shape.shape_of %9461 : tensor -> tensor<2xindex> + %9469 = shape.shape_of %9467 : tensor -> tensor<2xindex> + %9470 = shape.cstr_broadcastable %9468, %9469 : tensor<2xindex>, tensor<2xindex> + %9471 = shape.assuming %9470 -> (tensor) { + %19688 = shape.broadcast %9468, %9469 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %9461, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %9467, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %9472 = shape.shape_of %9471 : tensor -> tensor<2xindex> + %9473 = stablehlo.dynamic_broadcast_in_dim %9471, %9472, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %9474 = stablehlo.dynamic_broadcast_in_dim %213, %9472, dims = [] : (tensor, tensor<2xindex>) -> tensor + %9475 = stablehlo.multiply %9473, %9474 : tensor + %dim_3312 = tensor.dim %9435, %c0 : tensor + %9476 = arith.index_cast %dim_3312 : index to i64 + %dim_3313 = tensor.dim %9471, %c0 : tensor + %9477 = arith.index_cast %dim_3313 : index to i64 + %9478 = arith.maxsi %9476, %9477 : i64 + %9479 = arith.index_cast %9478 : i64 to index + %from_elements_3314 = tensor.from_elements %9479, %c4096 : tensor<2xindex> + %9480 = stablehlo.dynamic_broadcast_in_dim %9435, %from_elements_3314, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_3315 = tensor.dim %9480, %c0 : tensor + %9481 = arith.index_cast %dim_3315 : index to i64 + %from_elements_3316 = tensor.from_elements %9481, %c4096_i64 : tensor<2xi64> + %9482 = stablehlo.real_dynamic_slice %9475, %c_22, %from_elements_3316, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_3317 = tensor.from_elements %9481, %c4096_i64, %c1_i64 : tensor<3xi64> + %9483 = stablehlo.dynamic_reshape %9480, %from_elements_3317 : (tensor, 
tensor<3xi64>) -> tensor + %9484 = stablehlo.dynamic_iota %from_elements_3317, dim = 1 : (tensor<3xi64>) -> tensor + %9485 = stablehlo.concatenate %9483, %9484, dim = 2 : (tensor, tensor) -> tensor + %9486 = "stablehlo.scatter"(%9423, %9485, %9482) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %9487 = stablehlo.slice %9358 [2:3, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %9488 = stablehlo.reshape %9487 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %9489 = stablehlo.custom_call @byteir.non_zero(%9488) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_3318 = tensor.dim %9489, %c0 : tensor + %9490 = arith.index_cast %dim_3318 : index to i64 + %from_elements_3319 = tensor.from_elements %9490, %c1_i64 : tensor<2xi64> + %9491 = stablehlo.real_dynamic_slice %9489, %c_22, %from_elements_3319, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3320 = tensor.dim %9491, %c0 : tensor + %9492 = arith.index_cast %dim_3320 : index to i64 + %from_elements_3321 = tensor.from_elements %9492 : tensor<1xi64> + %9493 = stablehlo.dynamic_reshape %9491, %from_elements_3321 : (tensor, tensor<1xi64>) -> tensor + %from_elements_3322 = tensor.from_elements %9490, %c2_i64 : tensor<2xi64> + %9494 = stablehlo.real_dynamic_slice %9489, %c_24, %from_elements_3322, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3323 = tensor.dim %9494, %c0 : tensor + %9495 = arith.index_cast %dim_3323 : index to i64 + %from_elements_3324 = tensor.from_elements %9495 : tensor<1xi64> + %9496 = stablehlo.dynamic_reshape %9494, %from_elements_3324 : (tensor, tensor<1xi64>) -> tensor + %dim_3325 = tensor.dim %9496, %c0 : tensor + %9497 = arith.index_cast %dim_3325 : index to i64 + %from_elements_3326 = tensor.from_elements %9497, %c1_i64 : tensor<2xi64> + %9498 = stablehlo.dynamic_reshape %9496, %from_elements_3326 : (tensor, tensor<2xi64>) -> tensor + %dim_3327 = tensor.dim %9498, %c0 : tensor + %9499 = arith.index_cast %dim_3327 : index to i64 + %from_elements_3328 = tensor.from_elements %c1_i64, %9499, %c4096_i64 : tensor<3xi64> + %9500 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_3328, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3329 = tensor.dim %9500, %c1 : tensor<1x?x4096xi64> + %9501 = arith.index_cast %dim_3329 : index to i64 + %from_elements_3330 = tensor.from_elements %c1_i64, %9501, %c4096_i64, %c1_i64 : tensor<4xi64> + %9502 = stablehlo.dynamic_reshape %9500, %from_elements_3330 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %9503 = stablehlo.dynamic_broadcast_in_dim %9498, %from_elements_3328, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3331 = tensor.dim %9503, %c1 : tensor<1x?x4096xi64> + %9504 = arith.index_cast %dim_3331 : index to i64 + %from_elements_3332 = tensor.from_elements %c1_i64, %9504, %c4096_i64, %c1_i64 : tensor<4xi64> + %9505 = stablehlo.dynamic_reshape %9503, %from_elements_3332 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %9506 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_3328, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3333 = tensor.dim %9506, %c1 : tensor<1x?x4096xi64> + %9507 = arith.index_cast %dim_3333 : index 
to i64 + %from_elements_3334 = tensor.from_elements %c1_i64, %9507, %c4096_i64, %c1_i64 : tensor<4xi64> + %9508 = stablehlo.dynamic_reshape %9506, %from_elements_3334 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %9509 = stablehlo.concatenate %9502, %9505, %9508, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %9510 = "stablehlo.gather"(%9369, %9509) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %9511 = shape.shape_of %9510 : tensor<1x?x4096xf32> -> tensor<3xindex> + %9512 = shape.num_elements %9511 : tensor<3xindex> -> index + %9513 = stablehlo.compute_reshape_shape %9512, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %9514 = stablehlo.dynamic_reshape %9510, %9513 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %9515 = stablehlo.dot %9514, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %9516 = stablehlo.logistic %9515 : tensor + %9517 = shape.shape_of %9516 : tensor -> tensor<2xindex> + %9518 = shape.shape_of %9515 : tensor -> tensor<2xindex> + %9519 = shape.cstr_broadcastable %9517, %9518 : tensor<2xindex>, tensor<2xindex> + %9520 = shape.assuming %9519 -> (tensor) { + %19688 = shape.broadcast %9517, %9518 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %9516, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %9515, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %9521 = shape.shape_of %9520 : tensor -> tensor<2xindex> + %9522 = shape.cstr_broadcastable %9521, %9518 : tensor<2xindex>, tensor<2xindex> + %9523 = shape.assuming %9522 -> (tensor) { + %19688 = shape.broadcast %9521, %9518 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %9520, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %9515, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %9524 = stablehlo.dot %9523, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_3335 = tensor.dim %9496, %c0 : tensor + %9525 = arith.index_cast %dim_3335 : index to i64 + %from_elements_3336 = tensor.from_elements %9525, %c1_i64 : tensor<2xi64> + %9526 = stablehlo.dynamic_reshape %9496, %from_elements_3336 : (tensor, tensor<2xi64>) -> tensor + %dim_3337 = tensor.dim %9493, %c0 : tensor + %9527 = arith.index_cast %dim_3337 : index to i64 + %from_elements_3338 = tensor.from_elements %9527, %c1_i64 : tensor<2xi64> + %9528 = stablehlo.dynamic_reshape %9493, %from_elements_3338 : (tensor, tensor<2xi64>) -> tensor + %9529 = stablehlo.concatenate %9526, %9528, dim = 1 : (tensor, tensor) -> tensor + %9530 = "stablehlo.gather"(%9398, %9529) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %9531 = shape.shape_of %9524 : tensor -> tensor<2xindex> + %9532 = shape.shape_of %9530 : tensor -> tensor<2xindex> + %9533 = shape.cstr_broadcastable %9531, %9532 : tensor<2xindex>, tensor<2xindex> + %9534 = shape.assuming %9533 -> (tensor) { + %19688 = shape.broadcast %9531, %9532 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + 
%19689 = stablehlo.dynamic_broadcast_in_dim %9524, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %9530, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %9535 = shape.shape_of %9534 : tensor -> tensor<2xindex> + %9536 = stablehlo.dynamic_broadcast_in_dim %9534, %9535, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %9537 = stablehlo.dynamic_broadcast_in_dim %213, %9535, dims = [] : (tensor, tensor<2xindex>) -> tensor + %9538 = stablehlo.multiply %9536, %9537 : tensor + %dim_3339 = tensor.dim %9498, %c0 : tensor + %9539 = arith.index_cast %dim_3339 : index to i64 + %dim_3340 = tensor.dim %9534, %c0 : tensor + %9540 = arith.index_cast %dim_3340 : index to i64 + %9541 = arith.maxsi %9539, %9540 : i64 + %9542 = arith.index_cast %9541 : i64 to index + %from_elements_3341 = tensor.from_elements %9542, %c4096 : tensor<2xindex> + %9543 = stablehlo.dynamic_broadcast_in_dim %9498, %from_elements_3341, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_3342 = tensor.dim %9543, %c0 : tensor + %9544 = arith.index_cast %dim_3342 : index to i64 + %from_elements_3343 = tensor.from_elements %9544, %c4096_i64 : tensor<2xi64> + %9545 = stablehlo.real_dynamic_slice %9538, %c_22, %from_elements_3343, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_3344 = tensor.from_elements %9544, %c4096_i64, %c1_i64 : tensor<3xi64> + %9546 = stablehlo.dynamic_reshape %9543, %from_elements_3344 : (tensor, tensor<3xi64>) -> tensor + %9547 = stablehlo.dynamic_iota %from_elements_3344, dim = 1 : (tensor<3xi64>) -> tensor + %9548 = stablehlo.concatenate %9546, %9547, dim = 2 : (tensor, tensor) -> tensor + %9549 = "stablehlo.scatter"(%9486, %9548, %9545) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %9550 = stablehlo.slice %9358 [3:4, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %9551 = stablehlo.reshape %9550 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %9552 = stablehlo.custom_call @byteir.non_zero(%9551) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_3345 = tensor.dim %9552, %c0 : tensor + %9553 = arith.index_cast %dim_3345 : index to i64 + %from_elements_3346 = tensor.from_elements %9553, %c1_i64 : tensor<2xi64> + %9554 = stablehlo.real_dynamic_slice %9552, %c_22, %from_elements_3346, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3347 = tensor.dim %9554, %c0 : tensor + %9555 = arith.index_cast %dim_3347 : index to i64 + %from_elements_3348 = tensor.from_elements %9555 : tensor<1xi64> + %9556 = stablehlo.dynamic_reshape %9554, %from_elements_3348 : (tensor, tensor<1xi64>) -> tensor + %from_elements_3349 = tensor.from_elements %9553, %c2_i64 : tensor<2xi64> + %9557 = stablehlo.real_dynamic_slice %9552, %c_24, %from_elements_3349, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3350 = tensor.dim %9557, %c0 : tensor + %9558 = arith.index_cast %dim_3350 : index to i64 + %from_elements_3351 = tensor.from_elements %9558 : tensor<1xi64> + %9559 = stablehlo.dynamic_reshape %9557, %from_elements_3351 : (tensor, tensor<1xi64>) -> tensor + %dim_3352 = tensor.dim %9559, %c0 : tensor + 
%9560 = arith.index_cast %dim_3352 : index to i64 + %from_elements_3353 = tensor.from_elements %9560, %c1_i64 : tensor<2xi64> + %9561 = stablehlo.dynamic_reshape %9559, %from_elements_3353 : (tensor, tensor<2xi64>) -> tensor + %dim_3354 = tensor.dim %9561, %c0 : tensor + %9562 = arith.index_cast %dim_3354 : index to i64 + %from_elements_3355 = tensor.from_elements %c1_i64, %9562, %c4096_i64 : tensor<3xi64> + %9563 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_3355, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3356 = tensor.dim %9563, %c1 : tensor<1x?x4096xi64> + %9564 = arith.index_cast %dim_3356 : index to i64 + %from_elements_3357 = tensor.from_elements %c1_i64, %9564, %c4096_i64, %c1_i64 : tensor<4xi64> + %9565 = stablehlo.dynamic_reshape %9563, %from_elements_3357 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %9566 = stablehlo.dynamic_broadcast_in_dim %9561, %from_elements_3355, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3358 = tensor.dim %9566, %c1 : tensor<1x?x4096xi64> + %9567 = arith.index_cast %dim_3358 : index to i64 + %from_elements_3359 = tensor.from_elements %c1_i64, %9567, %c4096_i64, %c1_i64 : tensor<4xi64> + %9568 = stablehlo.dynamic_reshape %9566, %from_elements_3359 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %9569 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_3355, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3360 = tensor.dim %9569, %c1 : tensor<1x?x4096xi64> + %9570 = arith.index_cast %dim_3360 : index to i64 + %from_elements_3361 = tensor.from_elements %c1_i64, %9570, %c4096_i64, %c1_i64 : tensor<4xi64> + %9571 = stablehlo.dynamic_reshape %9569, %from_elements_3361 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %9572 = stablehlo.concatenate %9565, %9568, %9571, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %9573 = "stablehlo.gather"(%9369, %9572) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %9574 = shape.shape_of %9573 : tensor<1x?x4096xf32> -> tensor<3xindex> + %9575 = shape.num_elements %9574 : tensor<3xindex> -> index + %9576 = stablehlo.compute_reshape_shape %9575, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %9577 = stablehlo.dynamic_reshape %9573, %9576 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %9578 = stablehlo.dot %9577, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %9579 = stablehlo.logistic %9578 : tensor + %9580 = shape.shape_of %9579 : tensor -> tensor<2xindex> + %9581 = shape.shape_of %9578 : tensor -> tensor<2xindex> + %9582 = shape.cstr_broadcastable %9580, %9581 : tensor<2xindex>, tensor<2xindex> + %9583 = shape.assuming %9582 -> (tensor) { + %19688 = shape.broadcast %9580, %9581 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %9579, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %9578, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %9584 = shape.shape_of %9583 : tensor -> tensor<2xindex> + %9585 = shape.cstr_broadcastable %9584, %9581 : tensor<2xindex>, tensor<2xindex> + %9586 = shape.assuming %9585 -> (tensor) { + %19688 = shape.broadcast 
%9584, %9581 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %9583, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %9578, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %9587 = stablehlo.dot %9586, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_3362 = tensor.dim %9559, %c0 : tensor + %9588 = arith.index_cast %dim_3362 : index to i64 + %from_elements_3363 = tensor.from_elements %9588, %c1_i64 : tensor<2xi64> + %9589 = stablehlo.dynamic_reshape %9559, %from_elements_3363 : (tensor, tensor<2xi64>) -> tensor + %dim_3364 = tensor.dim %9556, %c0 : tensor + %9590 = arith.index_cast %dim_3364 : index to i64 + %from_elements_3365 = tensor.from_elements %9590, %c1_i64 : tensor<2xi64> + %9591 = stablehlo.dynamic_reshape %9556, %from_elements_3365 : (tensor, tensor<2xi64>) -> tensor + %9592 = stablehlo.concatenate %9589, %9591, dim = 1 : (tensor, tensor) -> tensor + %9593 = "stablehlo.gather"(%9398, %9592) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %9594 = shape.shape_of %9587 : tensor -> tensor<2xindex> + %9595 = shape.shape_of %9593 : tensor -> tensor<2xindex> + %9596 = shape.cstr_broadcastable %9594, %9595 : tensor<2xindex>, tensor<2xindex> + %9597 = shape.assuming %9596 -> (tensor) { + %19688 = shape.broadcast %9594, %9595 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %9587, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %9593, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %9598 = shape.shape_of %9597 : tensor -> tensor<2xindex> + %9599 = stablehlo.dynamic_broadcast_in_dim %9597, %9598, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %9600 = stablehlo.dynamic_broadcast_in_dim %213, %9598, dims = [] : (tensor, tensor<2xindex>) -> tensor + %9601 = stablehlo.multiply %9599, %9600 : tensor + %dim_3366 = tensor.dim %9561, %c0 : tensor + %9602 = arith.index_cast %dim_3366 : index to i64 + %dim_3367 = tensor.dim %9597, %c0 : tensor + %9603 = arith.index_cast %dim_3367 : index to i64 + %9604 = arith.maxsi %9602, %9603 : i64 + %9605 = arith.index_cast %9604 : i64 to index + %from_elements_3368 = tensor.from_elements %9605, %c4096 : tensor<2xindex> + %9606 = stablehlo.dynamic_broadcast_in_dim %9561, %from_elements_3368, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_3369 = tensor.dim %9606, %c0 : tensor + %9607 = arith.index_cast %dim_3369 : index to i64 + %from_elements_3370 = tensor.from_elements %9607, %c4096_i64 : tensor<2xi64> + %9608 = stablehlo.real_dynamic_slice %9601, %c_22, %from_elements_3370, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_3371 = tensor.from_elements %9607, %c4096_i64, %c1_i64 : tensor<3xi64> + %9609 = stablehlo.dynamic_reshape %9606, %from_elements_3371 : (tensor, tensor<3xi64>) -> tensor + %9610 = stablehlo.dynamic_iota %from_elements_3371, dim = 1 : (tensor<3xi64>) -> tensor + %9611 = stablehlo.concatenate %9609, %9610, dim = 2 : (tensor, tensor) -> tensor + %9612 = "stablehlo.scatter"(%9549, %9611, %9608) <{indices_are_sorted = false, scatter_dimension_numbers = 
#stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %9613 = stablehlo.slice %9358 [4:5, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %9614 = stablehlo.reshape %9613 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %9615 = stablehlo.custom_call @byteir.non_zero(%9614) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_3372 = tensor.dim %9615, %c0 : tensor + %9616 = arith.index_cast %dim_3372 : index to i64 + %from_elements_3373 = tensor.from_elements %9616, %c1_i64 : tensor<2xi64> + %9617 = stablehlo.real_dynamic_slice %9615, %c_22, %from_elements_3373, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3374 = tensor.dim %9617, %c0 : tensor + %9618 = arith.index_cast %dim_3374 : index to i64 + %from_elements_3375 = tensor.from_elements %9618 : tensor<1xi64> + %9619 = stablehlo.dynamic_reshape %9617, %from_elements_3375 : (tensor, tensor<1xi64>) -> tensor + %from_elements_3376 = tensor.from_elements %9616, %c2_i64 : tensor<2xi64> + %9620 = stablehlo.real_dynamic_slice %9615, %c_24, %from_elements_3376, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3377 = tensor.dim %9620, %c0 : tensor + %9621 = arith.index_cast %dim_3377 : index to i64 + %from_elements_3378 = tensor.from_elements %9621 : tensor<1xi64> + %9622 = stablehlo.dynamic_reshape %9620, %from_elements_3378 : (tensor, tensor<1xi64>) -> tensor + %dim_3379 = tensor.dim %9622, %c0 : tensor + %9623 = arith.index_cast %dim_3379 : index to i64 + %from_elements_3380 = tensor.from_elements %9623, %c1_i64 : tensor<2xi64> + %9624 = stablehlo.dynamic_reshape %9622, %from_elements_3380 : (tensor, tensor<2xi64>) -> tensor + %dim_3381 = tensor.dim %9624, %c0 : tensor + %9625 = arith.index_cast %dim_3381 : index to i64 + %from_elements_3382 = tensor.from_elements %c1_i64, %9625, %c4096_i64 : tensor<3xi64> + %9626 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_3382, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3383 = tensor.dim %9626, %c1 : tensor<1x?x4096xi64> + %9627 = arith.index_cast %dim_3383 : index to i64 + %from_elements_3384 = tensor.from_elements %c1_i64, %9627, %c4096_i64, %c1_i64 : tensor<4xi64> + %9628 = stablehlo.dynamic_reshape %9626, %from_elements_3384 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %9629 = stablehlo.dynamic_broadcast_in_dim %9624, %from_elements_3382, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3385 = tensor.dim %9629, %c1 : tensor<1x?x4096xi64> + %9630 = arith.index_cast %dim_3385 : index to i64 + %from_elements_3386 = tensor.from_elements %c1_i64, %9630, %c4096_i64, %c1_i64 : tensor<4xi64> + %9631 = stablehlo.dynamic_reshape %9629, %from_elements_3386 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %9632 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_3382, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3387 = tensor.dim %9632, %c1 : tensor<1x?x4096xi64> + %9633 = arith.index_cast %dim_3387 : index to i64 + %from_elements_3388 = tensor.from_elements %c1_i64, %9633, %c4096_i64, %c1_i64 : tensor<4xi64> + %9634 = stablehlo.dynamic_reshape %9632, %from_elements_3388 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %9635 = stablehlo.concatenate %9628, %9631, %9634, dim = 3 : 
(tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %9636 = "stablehlo.gather"(%9369, %9635) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %9637 = shape.shape_of %9636 : tensor<1x?x4096xf32> -> tensor<3xindex> + %9638 = shape.num_elements %9637 : tensor<3xindex> -> index + %9639 = stablehlo.compute_reshape_shape %9638, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %9640 = stablehlo.dynamic_reshape %9636, %9639 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %9641 = stablehlo.dot %9640, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %9642 = stablehlo.logistic %9641 : tensor + %9643 = shape.shape_of %9642 : tensor -> tensor<2xindex> + %9644 = shape.shape_of %9641 : tensor -> tensor<2xindex> + %9645 = shape.cstr_broadcastable %9643, %9644 : tensor<2xindex>, tensor<2xindex> + %9646 = shape.assuming %9645 -> (tensor) { + %19688 = shape.broadcast %9643, %9644 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %9642, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %9641, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %9647 = shape.shape_of %9646 : tensor -> tensor<2xindex> + %9648 = shape.cstr_broadcastable %9647, %9644 : tensor<2xindex>, tensor<2xindex> + %9649 = shape.assuming %9648 -> (tensor) { + %19688 = shape.broadcast %9647, %9644 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %9646, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %9641, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %9650 = stablehlo.dot %9649, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_3389 = tensor.dim %9622, %c0 : tensor + %9651 = arith.index_cast %dim_3389 : index to i64 + %from_elements_3390 = tensor.from_elements %9651, %c1_i64 : tensor<2xi64> + %9652 = stablehlo.dynamic_reshape %9622, %from_elements_3390 : (tensor, tensor<2xi64>) -> tensor + %dim_3391 = tensor.dim %9619, %c0 : tensor + %9653 = arith.index_cast %dim_3391 : index to i64 + %from_elements_3392 = tensor.from_elements %9653, %c1_i64 : tensor<2xi64> + %9654 = stablehlo.dynamic_reshape %9619, %from_elements_3392 : (tensor, tensor<2xi64>) -> tensor + %9655 = stablehlo.concatenate %9652, %9654, dim = 1 : (tensor, tensor) -> tensor + %9656 = "stablehlo.gather"(%9398, %9655) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %9657 = shape.shape_of %9650 : tensor -> tensor<2xindex> + %9658 = shape.shape_of %9656 : tensor -> tensor<2xindex> + %9659 = shape.cstr_broadcastable %9657, %9658 : tensor<2xindex>, tensor<2xindex> + %9660 = shape.assuming %9659 -> (tensor) { + %19688 = shape.broadcast %9657, %9658 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %9650, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %9656, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield 
%19691 : tensor + } + %9661 = shape.shape_of %9660 : tensor -> tensor<2xindex> + %9662 = stablehlo.dynamic_broadcast_in_dim %9660, %9661, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %9663 = stablehlo.dynamic_broadcast_in_dim %213, %9661, dims = [] : (tensor, tensor<2xindex>) -> tensor + %9664 = stablehlo.multiply %9662, %9663 : tensor + %dim_3393 = tensor.dim %9624, %c0 : tensor + %9665 = arith.index_cast %dim_3393 : index to i64 + %dim_3394 = tensor.dim %9660, %c0 : tensor + %9666 = arith.index_cast %dim_3394 : index to i64 + %9667 = arith.maxsi %9665, %9666 : i64 + %9668 = arith.index_cast %9667 : i64 to index + %from_elements_3395 = tensor.from_elements %9668, %c4096 : tensor<2xindex> + %9669 = stablehlo.dynamic_broadcast_in_dim %9624, %from_elements_3395, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_3396 = tensor.dim %9669, %c0 : tensor + %9670 = arith.index_cast %dim_3396 : index to i64 + %from_elements_3397 = tensor.from_elements %9670, %c4096_i64 : tensor<2xi64> + %9671 = stablehlo.real_dynamic_slice %9664, %c_22, %from_elements_3397, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_3398 = tensor.from_elements %9670, %c4096_i64, %c1_i64 : tensor<3xi64> + %9672 = stablehlo.dynamic_reshape %9669, %from_elements_3398 : (tensor, tensor<3xi64>) -> tensor + %9673 = stablehlo.dynamic_iota %from_elements_3398, dim = 1 : (tensor<3xi64>) -> tensor + %9674 = stablehlo.concatenate %9672, %9673, dim = 2 : (tensor, tensor) -> tensor + %9675 = "stablehlo.scatter"(%9612, %9674, %9671) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %9676 = stablehlo.slice %9358 [5:6, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %9677 = stablehlo.reshape %9676 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %9678 = stablehlo.custom_call @byteir.non_zero(%9677) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_3399 = tensor.dim %9678, %c0 : tensor + %9679 = arith.index_cast %dim_3399 : index to i64 + %from_elements_3400 = tensor.from_elements %9679, %c1_i64 : tensor<2xi64> + %9680 = stablehlo.real_dynamic_slice %9678, %c_22, %from_elements_3400, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3401 = tensor.dim %9680, %c0 : tensor + %9681 = arith.index_cast %dim_3401 : index to i64 + %from_elements_3402 = tensor.from_elements %9681 : tensor<1xi64> + %9682 = stablehlo.dynamic_reshape %9680, %from_elements_3402 : (tensor, tensor<1xi64>) -> tensor + %from_elements_3403 = tensor.from_elements %9679, %c2_i64 : tensor<2xi64> + %9683 = stablehlo.real_dynamic_slice %9678, %c_24, %from_elements_3403, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3404 = tensor.dim %9683, %c0 : tensor + %9684 = arith.index_cast %dim_3404 : index to i64 + %from_elements_3405 = tensor.from_elements %9684 : tensor<1xi64> + %9685 = stablehlo.dynamic_reshape %9683, %from_elements_3405 : (tensor, tensor<1xi64>) -> tensor + %dim_3406 = tensor.dim %9685, %c0 : tensor + %9686 = arith.index_cast %dim_3406 : index to i64 + %from_elements_3407 = tensor.from_elements %9686, %c1_i64 : tensor<2xi64> + %9687 = stablehlo.dynamic_reshape %9685, %from_elements_3407 : (tensor, tensor<2xi64>) -> tensor + %dim_3408 = tensor.dim %9687, %c0 : tensor + %9688 = arith.index_cast 
%dim_3408 : index to i64 + %from_elements_3409 = tensor.from_elements %c1_i64, %9688, %c4096_i64 : tensor<3xi64> + %9689 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_3409, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3410 = tensor.dim %9689, %c1 : tensor<1x?x4096xi64> + %9690 = arith.index_cast %dim_3410 : index to i64 + %from_elements_3411 = tensor.from_elements %c1_i64, %9690, %c4096_i64, %c1_i64 : tensor<4xi64> + %9691 = stablehlo.dynamic_reshape %9689, %from_elements_3411 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %9692 = stablehlo.dynamic_broadcast_in_dim %9687, %from_elements_3409, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3412 = tensor.dim %9692, %c1 : tensor<1x?x4096xi64> + %9693 = arith.index_cast %dim_3412 : index to i64 + %from_elements_3413 = tensor.from_elements %c1_i64, %9693, %c4096_i64, %c1_i64 : tensor<4xi64> + %9694 = stablehlo.dynamic_reshape %9692, %from_elements_3413 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %9695 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_3409, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3414 = tensor.dim %9695, %c1 : tensor<1x?x4096xi64> + %9696 = arith.index_cast %dim_3414 : index to i64 + %from_elements_3415 = tensor.from_elements %c1_i64, %9696, %c4096_i64, %c1_i64 : tensor<4xi64> + %9697 = stablehlo.dynamic_reshape %9695, %from_elements_3415 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %9698 = stablehlo.concatenate %9691, %9694, %9697, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %9699 = "stablehlo.gather"(%9369, %9698) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %9700 = shape.shape_of %9699 : tensor<1x?x4096xf32> -> tensor<3xindex> + %9701 = shape.num_elements %9700 : tensor<3xindex> -> index + %9702 = stablehlo.compute_reshape_shape %9701, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %9703 = stablehlo.dynamic_reshape %9699, %9702 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %9704 = stablehlo.dot %9703, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %9705 = stablehlo.logistic %9704 : tensor + %9706 = shape.shape_of %9705 : tensor -> tensor<2xindex> + %9707 = shape.shape_of %9704 : tensor -> tensor<2xindex> + %9708 = shape.cstr_broadcastable %9706, %9707 : tensor<2xindex>, tensor<2xindex> + %9709 = shape.assuming %9708 -> (tensor) { + %19688 = shape.broadcast %9706, %9707 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %9705, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %9704, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %9710 = shape.shape_of %9709 : tensor -> tensor<2xindex> + %9711 = shape.cstr_broadcastable %9710, %9707 : tensor<2xindex>, tensor<2xindex> + %9712 = shape.assuming %9711 -> (tensor) { + %19688 = shape.broadcast %9710, %9707 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %9709, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %9704, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + 
%19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %9713 = stablehlo.dot %9712, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_3416 = tensor.dim %9685, %c0 : tensor + %9714 = arith.index_cast %dim_3416 : index to i64 + %from_elements_3417 = tensor.from_elements %9714, %c1_i64 : tensor<2xi64> + %9715 = stablehlo.dynamic_reshape %9685, %from_elements_3417 : (tensor, tensor<2xi64>) -> tensor + %dim_3418 = tensor.dim %9682, %c0 : tensor + %9716 = arith.index_cast %dim_3418 : index to i64 + %from_elements_3419 = tensor.from_elements %9716, %c1_i64 : tensor<2xi64> + %9717 = stablehlo.dynamic_reshape %9682, %from_elements_3419 : (tensor, tensor<2xi64>) -> tensor + %9718 = stablehlo.concatenate %9715, %9717, dim = 1 : (tensor, tensor) -> tensor + %9719 = "stablehlo.gather"(%9398, %9718) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %9720 = shape.shape_of %9713 : tensor -> tensor<2xindex> + %9721 = shape.shape_of %9719 : tensor -> tensor<2xindex> + %9722 = shape.cstr_broadcastable %9720, %9721 : tensor<2xindex>, tensor<2xindex> + %9723 = shape.assuming %9722 -> (tensor) { + %19688 = shape.broadcast %9720, %9721 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %9713, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %9719, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %9724 = shape.shape_of %9723 : tensor -> tensor<2xindex> + %9725 = stablehlo.dynamic_broadcast_in_dim %9723, %9724, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %9726 = stablehlo.dynamic_broadcast_in_dim %213, %9724, dims = [] : (tensor, tensor<2xindex>) -> tensor + %9727 = stablehlo.multiply %9725, %9726 : tensor + %dim_3420 = tensor.dim %9687, %c0 : tensor + %9728 = arith.index_cast %dim_3420 : index to i64 + %dim_3421 = tensor.dim %9723, %c0 : tensor + %9729 = arith.index_cast %dim_3421 : index to i64 + %9730 = arith.maxsi %9728, %9729 : i64 + %9731 = arith.index_cast %9730 : i64 to index + %from_elements_3422 = tensor.from_elements %9731, %c4096 : tensor<2xindex> + %9732 = stablehlo.dynamic_broadcast_in_dim %9687, %from_elements_3422, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_3423 = tensor.dim %9732, %c0 : tensor + %9733 = arith.index_cast %dim_3423 : index to i64 + %from_elements_3424 = tensor.from_elements %9733, %c4096_i64 : tensor<2xi64> + %9734 = stablehlo.real_dynamic_slice %9727, %c_22, %from_elements_3424, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_3425 = tensor.from_elements %9733, %c4096_i64, %c1_i64 : tensor<3xi64> + %9735 = stablehlo.dynamic_reshape %9732, %from_elements_3425 : (tensor, tensor<3xi64>) -> tensor + %9736 = stablehlo.dynamic_iota %from_elements_3425, dim = 1 : (tensor<3xi64>) -> tensor + %9737 = stablehlo.concatenate %9735, %9736, dim = 2 : (tensor, tensor) -> tensor + %9738 = "stablehlo.scatter"(%9675, %9737, %9734) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %9739 = stablehlo.slice %9358 [6:7, 0:2, 0:3] : (tensor<8x2x3xi64>) 
-> tensor<1x2x3xi64> + %9740 = stablehlo.reshape %9739 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %9741 = stablehlo.custom_call @byteir.non_zero(%9740) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_3426 = tensor.dim %9741, %c0 : tensor + %9742 = arith.index_cast %dim_3426 : index to i64 + %from_elements_3427 = tensor.from_elements %9742, %c1_i64 : tensor<2xi64> + %9743 = stablehlo.real_dynamic_slice %9741, %c_22, %from_elements_3427, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3428 = tensor.dim %9743, %c0 : tensor + %9744 = arith.index_cast %dim_3428 : index to i64 + %from_elements_3429 = tensor.from_elements %9744 : tensor<1xi64> + %9745 = stablehlo.dynamic_reshape %9743, %from_elements_3429 : (tensor, tensor<1xi64>) -> tensor + %from_elements_3430 = tensor.from_elements %9742, %c2_i64 : tensor<2xi64> + %9746 = stablehlo.real_dynamic_slice %9741, %c_24, %from_elements_3430, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3431 = tensor.dim %9746, %c0 : tensor + %9747 = arith.index_cast %dim_3431 : index to i64 + %from_elements_3432 = tensor.from_elements %9747 : tensor<1xi64> + %9748 = stablehlo.dynamic_reshape %9746, %from_elements_3432 : (tensor, tensor<1xi64>) -> tensor + %dim_3433 = tensor.dim %9748, %c0 : tensor + %9749 = arith.index_cast %dim_3433 : index to i64 + %from_elements_3434 = tensor.from_elements %9749, %c1_i64 : tensor<2xi64> + %9750 = stablehlo.dynamic_reshape %9748, %from_elements_3434 : (tensor, tensor<2xi64>) -> tensor + %dim_3435 = tensor.dim %9750, %c0 : tensor + %9751 = arith.index_cast %dim_3435 : index to i64 + %from_elements_3436 = tensor.from_elements %c1_i64, %9751, %c4096_i64 : tensor<3xi64> + %9752 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_3436, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3437 = tensor.dim %9752, %c1 : tensor<1x?x4096xi64> + %9753 = arith.index_cast %dim_3437 : index to i64 + %from_elements_3438 = tensor.from_elements %c1_i64, %9753, %c4096_i64, %c1_i64 : tensor<4xi64> + %9754 = stablehlo.dynamic_reshape %9752, %from_elements_3438 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %9755 = stablehlo.dynamic_broadcast_in_dim %9750, %from_elements_3436, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3439 = tensor.dim %9755, %c1 : tensor<1x?x4096xi64> + %9756 = arith.index_cast %dim_3439 : index to i64 + %from_elements_3440 = tensor.from_elements %c1_i64, %9756, %c4096_i64, %c1_i64 : tensor<4xi64> + %9757 = stablehlo.dynamic_reshape %9755, %from_elements_3440 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %9758 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_3436, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3441 = tensor.dim %9758, %c1 : tensor<1x?x4096xi64> + %9759 = arith.index_cast %dim_3441 : index to i64 + %from_elements_3442 = tensor.from_elements %c1_i64, %9759, %c4096_i64, %c1_i64 : tensor<4xi64> + %9760 = stablehlo.dynamic_reshape %9758, %from_elements_3442 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %9761 = stablehlo.concatenate %9754, %9757, %9760, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %9762 = "stablehlo.gather"(%9369, %9761) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %9763 = 
shape.shape_of %9762 : tensor<1x?x4096xf32> -> tensor<3xindex> + %9764 = shape.num_elements %9763 : tensor<3xindex> -> index + %9765 = stablehlo.compute_reshape_shape %9764, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %9766 = stablehlo.dynamic_reshape %9762, %9765 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %9767 = stablehlo.dot %9766, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %9768 = stablehlo.logistic %9767 : tensor + %9769 = shape.shape_of %9768 : tensor -> tensor<2xindex> + %9770 = shape.shape_of %9767 : tensor -> tensor<2xindex> + %9771 = shape.cstr_broadcastable %9769, %9770 : tensor<2xindex>, tensor<2xindex> + %9772 = shape.assuming %9771 -> (tensor) { + %19688 = shape.broadcast %9769, %9770 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %9768, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %9767, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %9773 = shape.shape_of %9772 : tensor -> tensor<2xindex> + %9774 = shape.cstr_broadcastable %9773, %9770 : tensor<2xindex>, tensor<2xindex> + %9775 = shape.assuming %9774 -> (tensor) { + %19688 = shape.broadcast %9773, %9770 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %9772, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %9767, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %9776 = stablehlo.dot %9775, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_3443 = tensor.dim %9748, %c0 : tensor + %9777 = arith.index_cast %dim_3443 : index to i64 + %from_elements_3444 = tensor.from_elements %9777, %c1_i64 : tensor<2xi64> + %9778 = stablehlo.dynamic_reshape %9748, %from_elements_3444 : (tensor, tensor<2xi64>) -> tensor + %dim_3445 = tensor.dim %9745, %c0 : tensor + %9779 = arith.index_cast %dim_3445 : index to i64 + %from_elements_3446 = tensor.from_elements %9779, %c1_i64 : tensor<2xi64> + %9780 = stablehlo.dynamic_reshape %9745, %from_elements_3446 : (tensor, tensor<2xi64>) -> tensor + %9781 = stablehlo.concatenate %9778, %9780, dim = 1 : (tensor, tensor) -> tensor + %9782 = "stablehlo.gather"(%9398, %9781) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %9783 = shape.shape_of %9776 : tensor -> tensor<2xindex> + %9784 = shape.shape_of %9782 : tensor -> tensor<2xindex> + %9785 = shape.cstr_broadcastable %9783, %9784 : tensor<2xindex>, tensor<2xindex> + %9786 = shape.assuming %9785 -> (tensor) { + %19688 = shape.broadcast %9783, %9784 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %9776, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %9782, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %9787 = shape.shape_of %9786 : tensor -> tensor<2xindex> + %9788 = stablehlo.dynamic_broadcast_in_dim %9786, %9787, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %9789 = stablehlo.dynamic_broadcast_in_dim %213, %9787, dims = [] : (tensor, tensor<2xindex>) -> tensor + %9790 = 
stablehlo.multiply %9788, %9789 : tensor + %dim_3447 = tensor.dim %9750, %c0 : tensor + %9791 = arith.index_cast %dim_3447 : index to i64 + %dim_3448 = tensor.dim %9786, %c0 : tensor + %9792 = arith.index_cast %dim_3448 : index to i64 + %9793 = arith.maxsi %9791, %9792 : i64 + %9794 = arith.index_cast %9793 : i64 to index + %from_elements_3449 = tensor.from_elements %9794, %c4096 : tensor<2xindex> + %9795 = stablehlo.dynamic_broadcast_in_dim %9750, %from_elements_3449, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_3450 = tensor.dim %9795, %c0 : tensor + %9796 = arith.index_cast %dim_3450 : index to i64 + %from_elements_3451 = tensor.from_elements %9796, %c4096_i64 : tensor<2xi64> + %9797 = stablehlo.real_dynamic_slice %9790, %c_22, %from_elements_3451, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_3452 = tensor.from_elements %9796, %c4096_i64, %c1_i64 : tensor<3xi64> + %9798 = stablehlo.dynamic_reshape %9795, %from_elements_3452 : (tensor, tensor<3xi64>) -> tensor + %9799 = stablehlo.dynamic_iota %from_elements_3452, dim = 1 : (tensor<3xi64>) -> tensor + %9800 = stablehlo.concatenate %9798, %9799, dim = 2 : (tensor, tensor) -> tensor + %9801 = "stablehlo.scatter"(%9738, %9800, %9797) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %9802 = stablehlo.slice %9358 [7:8, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %9803 = stablehlo.reshape %9802 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %9804 = stablehlo.custom_call @byteir.non_zero(%9803) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_3453 = tensor.dim %9804, %c0 : tensor + %9805 = arith.index_cast %dim_3453 : index to i64 + %from_elements_3454 = tensor.from_elements %9805, %c1_i64 : tensor<2xi64> + %9806 = stablehlo.real_dynamic_slice %9804, %c_22, %from_elements_3454, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3455 = tensor.dim %9806, %c0 : tensor + %9807 = arith.index_cast %dim_3455 : index to i64 + %from_elements_3456 = tensor.from_elements %9807 : tensor<1xi64> + %9808 = stablehlo.dynamic_reshape %9806, %from_elements_3456 : (tensor, tensor<1xi64>) -> tensor + %from_elements_3457 = tensor.from_elements %9805, %c2_i64 : tensor<2xi64> + %9809 = stablehlo.real_dynamic_slice %9804, %c_24, %from_elements_3457, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3458 = tensor.dim %9809, %c0 : tensor + %9810 = arith.index_cast %dim_3458 : index to i64 + %from_elements_3459 = tensor.from_elements %9810 : tensor<1xi64> + %9811 = stablehlo.dynamic_reshape %9809, %from_elements_3459 : (tensor, tensor<1xi64>) -> tensor + %dim_3460 = tensor.dim %9811, %c0 : tensor + %9812 = arith.index_cast %dim_3460 : index to i64 + %from_elements_3461 = tensor.from_elements %9812, %c1_i64 : tensor<2xi64> + %9813 = stablehlo.dynamic_reshape %9811, %from_elements_3461 : (tensor, tensor<2xi64>) -> tensor + %dim_3462 = tensor.dim %9813, %c0 : tensor + %9814 = arith.index_cast %dim_3462 : index to i64 + %from_elements_3463 = tensor.from_elements %c1_i64, %9814, %c4096_i64 : tensor<3xi64> + %9815 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_3463, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3464 = tensor.dim %9815, %c1 : 
tensor<1x?x4096xi64> + %9816 = arith.index_cast %dim_3464 : index to i64 + %from_elements_3465 = tensor.from_elements %c1_i64, %9816, %c4096_i64, %c1_i64 : tensor<4xi64> + %9817 = stablehlo.dynamic_reshape %9815, %from_elements_3465 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %9818 = stablehlo.dynamic_broadcast_in_dim %9813, %from_elements_3463, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3466 = tensor.dim %9818, %c1 : tensor<1x?x4096xi64> + %9819 = arith.index_cast %dim_3466 : index to i64 + %from_elements_3467 = tensor.from_elements %c1_i64, %9819, %c4096_i64, %c1_i64 : tensor<4xi64> + %9820 = stablehlo.dynamic_reshape %9818, %from_elements_3467 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %9821 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_3463, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3468 = tensor.dim %9821, %c1 : tensor<1x?x4096xi64> + %9822 = arith.index_cast %dim_3468 : index to i64 + %from_elements_3469 = tensor.from_elements %c1_i64, %9822, %c4096_i64, %c1_i64 : tensor<4xi64> + %9823 = stablehlo.dynamic_reshape %9821, %from_elements_3469 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %9824 = stablehlo.concatenate %9817, %9820, %9823, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %9825 = "stablehlo.gather"(%9369, %9824) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %9826 = shape.shape_of %9825 : tensor<1x?x4096xf32> -> tensor<3xindex> + %9827 = shape.num_elements %9826 : tensor<3xindex> -> index + %9828 = stablehlo.compute_reshape_shape %9827, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %9829 = stablehlo.dynamic_reshape %9825, %9828 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %9830 = stablehlo.dot %9829, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %9831 = stablehlo.logistic %9830 : tensor + %9832 = shape.shape_of %9831 : tensor -> tensor<2xindex> + %9833 = shape.shape_of %9830 : tensor -> tensor<2xindex> + %9834 = shape.cstr_broadcastable %9832, %9833 : tensor<2xindex>, tensor<2xindex> + %9835 = shape.assuming %9834 -> (tensor) { + %19688 = shape.broadcast %9832, %9833 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %9831, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %9830, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %9836 = shape.shape_of %9835 : tensor -> tensor<2xindex> + %9837 = shape.cstr_broadcastable %9836, %9833 : tensor<2xindex>, tensor<2xindex> + %9838 = shape.assuming %9837 -> (tensor) { + %19688 = shape.broadcast %9836, %9833 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %9835, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %9830, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %9839 = stablehlo.dot %9838, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_3470 = tensor.dim %9811, %c0 : tensor + %9840 = arith.index_cast %dim_3470 : index to i64 + %from_elements_3471 = 
tensor.from_elements %9840, %c1_i64 : tensor<2xi64> + %9841 = stablehlo.dynamic_reshape %9811, %from_elements_3471 : (tensor, tensor<2xi64>) -> tensor + %dim_3472 = tensor.dim %9808, %c0 : tensor + %9842 = arith.index_cast %dim_3472 : index to i64 + %from_elements_3473 = tensor.from_elements %9842, %c1_i64 : tensor<2xi64> + %9843 = stablehlo.dynamic_reshape %9808, %from_elements_3473 : (tensor, tensor<2xi64>) -> tensor + %9844 = stablehlo.concatenate %9841, %9843, dim = 1 : (tensor, tensor) -> tensor + %9845 = "stablehlo.gather"(%9398, %9844) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %9846 = shape.shape_of %9839 : tensor -> tensor<2xindex> + %9847 = shape.shape_of %9845 : tensor -> tensor<2xindex> + %9848 = shape.cstr_broadcastable %9846, %9847 : tensor<2xindex>, tensor<2xindex> + %9849 = shape.assuming %9848 -> (tensor) { + %19688 = shape.broadcast %9846, %9847 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %9839, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %9845, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %9850 = shape.shape_of %9849 : tensor -> tensor<2xindex> + %9851 = stablehlo.dynamic_broadcast_in_dim %9849, %9850, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %9852 = stablehlo.dynamic_broadcast_in_dim %213, %9850, dims = [] : (tensor, tensor<2xindex>) -> tensor + %9853 = stablehlo.multiply %9851, %9852 : tensor + %dim_3474 = tensor.dim %9813, %c0 : tensor + %9854 = arith.index_cast %dim_3474 : index to i64 + %dim_3475 = tensor.dim %9849, %c0 : tensor + %9855 = arith.index_cast %dim_3475 : index to i64 + %9856 = arith.maxsi %9854, %9855 : i64 + %9857 = arith.index_cast %9856 : i64 to index + %from_elements_3476 = tensor.from_elements %9857, %c4096 : tensor<2xindex> + %9858 = stablehlo.dynamic_broadcast_in_dim %9813, %from_elements_3476, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_3477 = tensor.dim %9858, %c0 : tensor + %9859 = arith.index_cast %dim_3477 : index to i64 + %from_elements_3478 = tensor.from_elements %9859, %c4096_i64 : tensor<2xi64> + %9860 = stablehlo.real_dynamic_slice %9853, %c_22, %from_elements_3478, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_3479 = tensor.from_elements %9859, %c4096_i64, %c1_i64 : tensor<3xi64> + %9861 = stablehlo.dynamic_reshape %9858, %from_elements_3479 : (tensor, tensor<3xi64>) -> tensor + %9862 = stablehlo.dynamic_iota %from_elements_3479, dim = 1 : (tensor<3xi64>) -> tensor + %9863 = stablehlo.concatenate %9861, %9862, dim = 2 : (tensor, tensor) -> tensor + %9864 = "stablehlo.scatter"(%9801, %9863, %9860) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %9865 = stablehlo.reshape %9864 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %9866 = stablehlo.add %9331, %9865 : tensor<3x1x4096xf32> + %9867 = stablehlo.broadcast_in_dim %9866, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %9868 = stablehlo.power %9867, %15 : tensor<3x1x4096xf32> + %9869 = stablehlo.reduce(%9868 init: %cst_20) applies 
stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %9870 = stablehlo.reshape %9869 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %9871 = stablehlo.broadcast_in_dim %9870, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %9872 = stablehlo.divide %9871, %21 : tensor<3x1x1xf32> + %9873 = stablehlo.broadcast_in_dim %9872, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %9874 = stablehlo.add %9873, %25 : tensor<3x1x1xf32> + %9875 = stablehlo.rsqrt %9874 : tensor<3x1x1xf32> + %9876 = stablehlo.broadcast_in_dim %9875, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %9877 = stablehlo.multiply %9867, %9876 : tensor<3x1x4096xf32> + %9878 = stablehlo.broadcast_in_dim %9877, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %9879 = stablehlo.multiply %9878, %31 : tensor<3x1x4096xf32> + %9880 = stablehlo.reshape %9879 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %9881 = stablehlo.dot %9880, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %9882 = stablehlo.reshape %9881 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %9883 = stablehlo.dot %9880, %37 : (tensor<3x4096xf32>, tensor<4096x1024xf32>) -> tensor<3x1024xf32> + %9884 = stablehlo.reshape %9883 : (tensor<3x1024xf32>) -> tensor<3x1x1024xf32> + %9885 = stablehlo.reshape %9882 : (tensor<3x1x4096xf32>) -> tensor<3x1x32x128xf32> + %9886 = stablehlo.transpose %9885, dims = [0, 2, 1, 3] : (tensor<3x1x32x128xf32>) -> tensor<3x32x1x128xf32> + %9887 = stablehlo.reshape %9884 : (tensor<3x1x1024xf32>) -> tensor<3x1x8x128xf32> + %9888 = stablehlo.transpose %9887, dims = [0, 2, 1, 3] : (tensor<3x1x8x128xf32>) -> tensor<3x8x1x128xf32> + %9889 = stablehlo.slice %arg32 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %9890 = stablehlo.slice %arg33 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %9891 = "stablehlo.gather"(%9889, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %9892 = stablehlo.reshape %9891 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %9893 = "stablehlo.gather"(%9890, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %9894 = stablehlo.reshape %9893 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %9895 = stablehlo.broadcast_in_dim %9886, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %9896 = stablehlo.broadcast_in_dim %9892, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %9897 = stablehlo.multiply %9895, %9896 : tensor<3x32x1x128xf32> + %9898 = stablehlo.slice %9886 [0:3, 0:32, 0:1, 0:64] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %9899 = stablehlo.slice %9886 [0:3, 0:32, 0:1, 64:128] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %9900 = stablehlo.negate %9899 : tensor<3x32x1x64xf32> + %9901 = stablehlo.concatenate %9900, %9898, dim = 3 : (tensor<3x32x1x64xf32>, tensor<3x32x1x64xf32>) -> tensor<3x32x1x128xf32> + %9902 = stablehlo.broadcast_in_dim %9901, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %9903 = stablehlo.broadcast_in_dim %9894, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %9904 = stablehlo.multiply %9902, %9903 : tensor<3x32x1x128xf32> + %9905 = stablehlo.add %9897, %9904 : tensor<3x32x1x128xf32> + %9906 = 
stablehlo.broadcast_in_dim %9888, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %9907 = stablehlo.broadcast_in_dim %9892, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %9908 = stablehlo.multiply %9906, %9907 : tensor<3x8x1x128xf32> + %9909 = stablehlo.slice %9888 [0:3, 0:8, 0:1, 0:64] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %9910 = stablehlo.slice %9888 [0:3, 0:8, 0:1, 64:128] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %9911 = stablehlo.negate %9910 : tensor<3x8x1x64xf32> + %9912 = stablehlo.concatenate %9911, %9909, dim = 3 : (tensor<3x8x1x64xf32>, tensor<3x8x1x64xf32>) -> tensor<3x8x1x128xf32> + %9913 = stablehlo.broadcast_in_dim %9912, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %9914 = stablehlo.broadcast_in_dim %9894, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %9915 = stablehlo.multiply %9913, %9914 : tensor<3x8x1x128xf32> + %9916 = stablehlo.add %9908, %9915 : tensor<3x8x1x128xf32> + %9917 = stablehlo.concatenate %arg97, %9916, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %9918 = stablehlo.concatenate %arg98, %9888, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %9919 = stablehlo.reshape %9917 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %9920 = stablehlo.broadcast_in_dim %9919, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %9921 = stablehlo.reshape %9920 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %9922 = stablehlo.reshape %9918 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %9923 = stablehlo.broadcast_in_dim %9922, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %9924 = stablehlo.reshape %9923 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %9925 = stablehlo.transpose %9921, dims = [0, 1, 3, 2] : (tensor<3x32x8x128xf32>) -> tensor<3x32x128x8xf32> + %9926 = stablehlo.reshape %9905 : (tensor<3x32x1x128xf32>) -> tensor<96x1x128xf32> + %9927 = stablehlo.reshape %9925 : (tensor<3x32x128x8xf32>) -> tensor<96x128x8xf32> + %9928 = stablehlo.broadcast_in_dim %9927, dims = [0, 1, 2] : (tensor<96x128x8xf32>) -> tensor<96x128x8xf32> + %9929 = stablehlo.dot_general %9926, %9928, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x128xf32>, tensor<96x128x8xf32>) -> tensor<96x1x8xf32> + %9930 = stablehlo.reshape %9929 : (tensor<96x1x8xf32>) -> tensor<3x32x1x8xf32> + %9931 = stablehlo.broadcast_in_dim %9930, dims = [0, 1, 2, 3] : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %9932 = stablehlo.divide %9931, %89 : tensor<3x32x1x8xf32> + %9933 = stablehlo.custom_call @byteir.softmax(%9932) {byteir_attrs = {axis = 3 : i64}} : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %9934 = stablehlo.reshape %9933 : (tensor<3x32x1x8xf32>) -> tensor<96x1x8xf32> + %9935 = stablehlo.reshape %9924 : (tensor<3x32x8x128xf32>) -> tensor<96x8x128xf32> + %9936 = stablehlo.broadcast_in_dim %9935, dims = [0, 1, 2] : (tensor<96x8x128xf32>) -> tensor<96x8x128xf32> + %9937 = stablehlo.dot_general %9934, %9936, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x8xf32>, tensor<96x8x128xf32>) -> tensor<96x1x128xf32> + %9938 = stablehlo.reshape %9937 : (tensor<96x1x128xf32>) -> tensor<3x32x1x128xf32> + %9939 = stablehlo.transpose %9938, dims = [0, 2, 1, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x1x32x128xf32> + %9940 = stablehlo.reshape %9939 : (tensor<3x1x32x128xf32>) -> 
tensor<3x1x4096xf32> + %9941 = stablehlo.reshape %9940 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %9942 = stablehlo.dot %9941, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %9943 = stablehlo.reshape %9942 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %9944 = stablehlo.add %9866, %9943 : tensor<3x1x4096xf32> + %9945 = stablehlo.broadcast_in_dim %9944, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %9946 = stablehlo.power %9945, %15 : tensor<3x1x4096xf32> + %9947 = stablehlo.reduce(%9946 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %9948 = stablehlo.reshape %9947 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %9949 = stablehlo.broadcast_in_dim %9948, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %9950 = stablehlo.divide %9949, %21 : tensor<3x1x1xf32> + %9951 = stablehlo.broadcast_in_dim %9950, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %9952 = stablehlo.add %9951, %25 : tensor<3x1x1xf32> + %9953 = stablehlo.rsqrt %9952 : tensor<3x1x1xf32> + %9954 = stablehlo.broadcast_in_dim %9953, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %9955 = stablehlo.multiply %9945, %9954 : tensor<3x1x4096xf32> + %9956 = stablehlo.broadcast_in_dim %9955, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %9957 = stablehlo.multiply %9956, %31 : tensor<3x1x4096xf32> + %9958 = stablehlo.reshape %9957 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %9959 = stablehlo.dot %9958, %117 : (tensor<3x4096xf32>, tensor<4096x8xf32>) -> tensor<3x8xf32> + %9960 = stablehlo.custom_call @byteir.softmax(%9959) {byteir_attrs = {axis = 1 : i64}} : (tensor<3x8xf32>) -> tensor<3x8xf32> + %9961:2 = stablehlo.custom_call @byteir.top_k(%9960) {byteir_attrs = {axis = [1], k = 2 : i64, sorted = true}} : (tensor<3x8xf32>) -> (tensor<3x2xf32>, tensor<3x2xi64>) + %9962 = stablehlo.reduce(%9961#0 init: %cst_20) applies stablehlo.add across dimensions = [1] : (tensor<3x2xf32>, tensor) -> tensor<3xf32> + %9963 = stablehlo.reshape %9962 : (tensor<3xf32>) -> tensor<3x1xf32> + %9964 = stablehlo.broadcast_in_dim %9961#0, dims = [0, 1] : (tensor<3x2xf32>) -> tensor<3x2xf32> + %9965 = stablehlo.broadcast_in_dim %9963, dims = [0, 1] : (tensor<3x1xf32>) -> tensor<3x2xf32> + %9966 = stablehlo.divide %9964, %9965 : tensor<3x2xf32> + %9967 = stablehlo.reshape %9961#1 : (tensor<3x2xi64>) -> tensor<3x2x1xi64> + %9968 = stablehlo.broadcast_in_dim %9967, dims = [0, 1, 2] : (tensor<3x2x1xi64>) -> tensor<3x2x8xi64> + %9969 = stablehlo.compare EQ, %9968, %137, SIGNED : (tensor<3x2x8xi64>, tensor<3x2x8xi64>) -> tensor<3x2x8xi1> + %9970 = stablehlo.convert %9969 : (tensor<3x2x8xi1>) -> tensor<3x2x8xi64> + %9971 = stablehlo.transpose %9970, dims = [2, 1, 0] : (tensor<3x2x8xi64>) -> tensor<8x2x3xi64> + %9972 = stablehlo.slice %9971 [0:1, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %9973 = stablehlo.reshape %9972 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %9974 = stablehlo.custom_call @byteir.non_zero(%9973) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_3480 = tensor.dim %9974, %c0 : tensor + %9975 = arith.index_cast %dim_3480 : index to i64 + %from_elements_3481 = tensor.from_elements %9975, %c1_i64 : tensor<2xi64> + %9976 = stablehlo.real_dynamic_slice %9974, %c_22, %from_elements_3481, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3482 = tensor.dim %9976, %c0 : tensor + %9977 = arith.index_cast %dim_3482 : index to i64 
+ %from_elements_3483 = tensor.from_elements %9977 : tensor<1xi64> + %9978 = stablehlo.dynamic_reshape %9976, %from_elements_3483 : (tensor, tensor<1xi64>) -> tensor + %from_elements_3484 = tensor.from_elements %9975, %c2_i64 : tensor<2xi64> + %9979 = stablehlo.real_dynamic_slice %9974, %c_24, %from_elements_3484, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3485 = tensor.dim %9979, %c0 : tensor + %9980 = arith.index_cast %dim_3485 : index to i64 + %from_elements_3486 = tensor.from_elements %9980 : tensor<1xi64> + %9981 = stablehlo.dynamic_reshape %9979, %from_elements_3486 : (tensor, tensor<1xi64>) -> tensor + %9982 = stablehlo.reshape %9958 : (tensor<3x4096xf32>) -> tensor<1x3x4096xf32> + %dim_3487 = tensor.dim %9981, %c0 : tensor + %9983 = arith.index_cast %dim_3487 : index to i64 + %from_elements_3488 = tensor.from_elements %9983, %c1_i64 : tensor<2xi64> + %9984 = stablehlo.dynamic_reshape %9981, %from_elements_3488 : (tensor, tensor<2xi64>) -> tensor + %dim_3489 = tensor.dim %9984, %c0 : tensor + %9985 = arith.index_cast %dim_3489 : index to i64 + %from_elements_3490 = tensor.from_elements %c1_i64, %9985, %c4096_i64 : tensor<3xi64> + %9986 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_3490, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3491 = tensor.dim %9986, %c1 : tensor<1x?x4096xi64> + %9987 = arith.index_cast %dim_3491 : index to i64 + %from_elements_3492 = tensor.from_elements %c1_i64, %9987, %c4096_i64, %c1_i64 : tensor<4xi64> + %9988 = stablehlo.dynamic_reshape %9986, %from_elements_3492 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %9989 = stablehlo.dynamic_broadcast_in_dim %9984, %from_elements_3490, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3493 = tensor.dim %9989, %c1 : tensor<1x?x4096xi64> + %9990 = arith.index_cast %dim_3493 : index to i64 + %from_elements_3494 = tensor.from_elements %c1_i64, %9990, %c4096_i64, %c1_i64 : tensor<4xi64> + %9991 = stablehlo.dynamic_reshape %9989, %from_elements_3494 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %9992 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_3490, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3495 = tensor.dim %9992, %c1 : tensor<1x?x4096xi64> + %9993 = arith.index_cast %dim_3495 : index to i64 + %from_elements_3496 = tensor.from_elements %c1_i64, %9993, %c4096_i64, %c1_i64 : tensor<4xi64> + %9994 = stablehlo.dynamic_reshape %9992, %from_elements_3496 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %9995 = stablehlo.concatenate %9988, %9991, %9994, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %9996 = "stablehlo.gather"(%9982, %9995) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %9997 = shape.shape_of %9996 : tensor<1x?x4096xf32> -> tensor<3xindex> + %9998 = shape.num_elements %9997 : tensor<3xindex> -> index + %9999 = stablehlo.compute_reshape_shape %9998, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %10000 = stablehlo.dynamic_reshape %9996, %9999 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %10001 = stablehlo.dot %10000, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %10002 = stablehlo.logistic %10001 : tensor + %10003 = shape.shape_of %10002 : tensor -> tensor<2xindex> + %10004 = shape.shape_of %10001 
: tensor -> tensor<2xindex> + %10005 = shape.cstr_broadcastable %10003, %10004 : tensor<2xindex>, tensor<2xindex> + %10006 = shape.assuming %10005 -> (tensor) { + %19688 = shape.broadcast %10003, %10004 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %10002, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %10001, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %10007 = shape.shape_of %10006 : tensor -> tensor<2xindex> + %10008 = shape.cstr_broadcastable %10007, %10004 : tensor<2xindex>, tensor<2xindex> + %10009 = shape.assuming %10008 -> (tensor) { + %19688 = shape.broadcast %10007, %10004 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %10006, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %10001, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %10010 = stablehlo.dot %10009, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %10011 = stablehlo.reshape %9966 : (tensor<3x2xf32>) -> tensor<3x2x1xf32> + %dim_3497 = tensor.dim %9981, %c0 : tensor + %10012 = arith.index_cast %dim_3497 : index to i64 + %from_elements_3498 = tensor.from_elements %10012, %c1_i64 : tensor<2xi64> + %10013 = stablehlo.dynamic_reshape %9981, %from_elements_3498 : (tensor, tensor<2xi64>) -> tensor + %dim_3499 = tensor.dim %9978, %c0 : tensor + %10014 = arith.index_cast %dim_3499 : index to i64 + %from_elements_3500 = tensor.from_elements %10014, %c1_i64 : tensor<2xi64> + %10015 = stablehlo.dynamic_reshape %9978, %from_elements_3500 : (tensor, tensor<2xi64>) -> tensor + %10016 = stablehlo.concatenate %10013, %10015, dim = 1 : (tensor, tensor) -> tensor + %10017 = "stablehlo.gather"(%10011, %10016) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %10018 = shape.shape_of %10010 : tensor -> tensor<2xindex> + %10019 = shape.shape_of %10017 : tensor -> tensor<2xindex> + %10020 = shape.cstr_broadcastable %10018, %10019 : tensor<2xindex>, tensor<2xindex> + %10021 = shape.assuming %10020 -> (tensor) { + %19688 = shape.broadcast %10018, %10019 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %10010, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %10017, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %10022 = shape.shape_of %10021 : tensor -> tensor<2xindex> + %10023 = stablehlo.dynamic_broadcast_in_dim %10021, %10022, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %10024 = stablehlo.dynamic_broadcast_in_dim %213, %10022, dims = [] : (tensor, tensor<2xindex>) -> tensor + %10025 = stablehlo.multiply %10023, %10024 : tensor + %dim_3501 = tensor.dim %9984, %c0 : tensor + %10026 = arith.index_cast %dim_3501 : index to i64 + %dim_3502 = tensor.dim %10021, %c0 : tensor + %10027 = arith.index_cast %dim_3502 : index to i64 + %10028 = arith.maxsi %10026, %10027 : i64 + %10029 = arith.index_cast %10028 : i64 to index + %from_elements_3503 = tensor.from_elements %10029, %c4096 : tensor<2xindex> + 
%10030 = stablehlo.dynamic_broadcast_in_dim %9984, %from_elements_3503, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_3504 = tensor.dim %10030, %c0 : tensor + %10031 = arith.index_cast %dim_3504 : index to i64 + %from_elements_3505 = tensor.from_elements %10031, %c4096_i64 : tensor<2xi64> + %10032 = stablehlo.real_dynamic_slice %10025, %c_22, %from_elements_3505, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_3506 = tensor.from_elements %10031, %c4096_i64, %c1_i64 : tensor<3xi64> + %10033 = stablehlo.dynamic_reshape %10030, %from_elements_3506 : (tensor, tensor<3xi64>) -> tensor + %10034 = stablehlo.dynamic_iota %from_elements_3506, dim = 1 : (tensor<3xi64>) -> tensor + %10035 = stablehlo.concatenate %10033, %10034, dim = 2 : (tensor, tensor) -> tensor + %10036 = "stablehlo.scatter"(%cst_2, %10035, %10032) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %10037 = stablehlo.slice %9971 [1:2, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %10038 = stablehlo.reshape %10037 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %10039 = stablehlo.custom_call @byteir.non_zero(%10038) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_3507 = tensor.dim %10039, %c0 : tensor + %10040 = arith.index_cast %dim_3507 : index to i64 + %from_elements_3508 = tensor.from_elements %10040, %c1_i64 : tensor<2xi64> + %10041 = stablehlo.real_dynamic_slice %10039, %c_22, %from_elements_3508, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3509 = tensor.dim %10041, %c0 : tensor + %10042 = arith.index_cast %dim_3509 : index to i64 + %from_elements_3510 = tensor.from_elements %10042 : tensor<1xi64> + %10043 = stablehlo.dynamic_reshape %10041, %from_elements_3510 : (tensor, tensor<1xi64>) -> tensor + %from_elements_3511 = tensor.from_elements %10040, %c2_i64 : tensor<2xi64> + %10044 = stablehlo.real_dynamic_slice %10039, %c_24, %from_elements_3511, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3512 = tensor.dim %10044, %c0 : tensor + %10045 = arith.index_cast %dim_3512 : index to i64 + %from_elements_3513 = tensor.from_elements %10045 : tensor<1xi64> + %10046 = stablehlo.dynamic_reshape %10044, %from_elements_3513 : (tensor, tensor<1xi64>) -> tensor + %dim_3514 = tensor.dim %10046, %c0 : tensor + %10047 = arith.index_cast %dim_3514 : index to i64 + %from_elements_3515 = tensor.from_elements %10047, %c1_i64 : tensor<2xi64> + %10048 = stablehlo.dynamic_reshape %10046, %from_elements_3515 : (tensor, tensor<2xi64>) -> tensor + %dim_3516 = tensor.dim %10048, %c0 : tensor + %10049 = arith.index_cast %dim_3516 : index to i64 + %from_elements_3517 = tensor.from_elements %c1_i64, %10049, %c4096_i64 : tensor<3xi64> + %10050 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_3517, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3518 = tensor.dim %10050, %c1 : tensor<1x?x4096xi64> + %10051 = arith.index_cast %dim_3518 : index to i64 + %from_elements_3519 = tensor.from_elements %c1_i64, %10051, %c4096_i64, %c1_i64 : tensor<4xi64> + %10052 = stablehlo.dynamic_reshape %10050, %from_elements_3519 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %10053 = stablehlo.dynamic_broadcast_in_dim %10048, 
%from_elements_3517, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3520 = tensor.dim %10053, %c1 : tensor<1x?x4096xi64> + %10054 = arith.index_cast %dim_3520 : index to i64 + %from_elements_3521 = tensor.from_elements %c1_i64, %10054, %c4096_i64, %c1_i64 : tensor<4xi64> + %10055 = stablehlo.dynamic_reshape %10053, %from_elements_3521 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %10056 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_3517, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3522 = tensor.dim %10056, %c1 : tensor<1x?x4096xi64> + %10057 = arith.index_cast %dim_3522 : index to i64 + %from_elements_3523 = tensor.from_elements %c1_i64, %10057, %c4096_i64, %c1_i64 : tensor<4xi64> + %10058 = stablehlo.dynamic_reshape %10056, %from_elements_3523 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %10059 = stablehlo.concatenate %10052, %10055, %10058, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %10060 = "stablehlo.gather"(%9982, %10059) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %10061 = shape.shape_of %10060 : tensor<1x?x4096xf32> -> tensor<3xindex> + %10062 = shape.num_elements %10061 : tensor<3xindex> -> index + %10063 = stablehlo.compute_reshape_shape %10062, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %10064 = stablehlo.dynamic_reshape %10060, %10063 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %10065 = stablehlo.dot %10064, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %10066 = stablehlo.logistic %10065 : tensor + %10067 = shape.shape_of %10066 : tensor -> tensor<2xindex> + %10068 = shape.shape_of %10065 : tensor -> tensor<2xindex> + %10069 = shape.cstr_broadcastable %10067, %10068 : tensor<2xindex>, tensor<2xindex> + %10070 = shape.assuming %10069 -> (tensor) { + %19688 = shape.broadcast %10067, %10068 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %10066, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %10065, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %10071 = shape.shape_of %10070 : tensor -> tensor<2xindex> + %10072 = shape.cstr_broadcastable %10071, %10068 : tensor<2xindex>, tensor<2xindex> + %10073 = shape.assuming %10072 -> (tensor) { + %19688 = shape.broadcast %10071, %10068 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %10070, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %10065, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %10074 = stablehlo.dot %10073, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_3524 = tensor.dim %10046, %c0 : tensor + %10075 = arith.index_cast %dim_3524 : index to i64 + %from_elements_3525 = tensor.from_elements %10075, %c1_i64 : tensor<2xi64> + %10076 = stablehlo.dynamic_reshape %10046, %from_elements_3525 : (tensor, tensor<2xi64>) -> tensor + %dim_3526 = tensor.dim %10043, %c0 : tensor + %10077 = arith.index_cast %dim_3526 : index to i64 + %from_elements_3527 = tensor.from_elements 
%10077, %c1_i64 : tensor<2xi64> + %10078 = stablehlo.dynamic_reshape %10043, %from_elements_3527 : (tensor, tensor<2xi64>) -> tensor + %10079 = stablehlo.concatenate %10076, %10078, dim = 1 : (tensor, tensor) -> tensor + %10080 = "stablehlo.gather"(%10011, %10079) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %10081 = shape.shape_of %10074 : tensor -> tensor<2xindex> + %10082 = shape.shape_of %10080 : tensor -> tensor<2xindex> + %10083 = shape.cstr_broadcastable %10081, %10082 : tensor<2xindex>, tensor<2xindex> + %10084 = shape.assuming %10083 -> (tensor) { + %19688 = shape.broadcast %10081, %10082 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %10074, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %10080, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %10085 = shape.shape_of %10084 : tensor -> tensor<2xindex> + %10086 = stablehlo.dynamic_broadcast_in_dim %10084, %10085, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %10087 = stablehlo.dynamic_broadcast_in_dim %213, %10085, dims = [] : (tensor, tensor<2xindex>) -> tensor + %10088 = stablehlo.multiply %10086, %10087 : tensor + %dim_3528 = tensor.dim %10048, %c0 : tensor + %10089 = arith.index_cast %dim_3528 : index to i64 + %dim_3529 = tensor.dim %10084, %c0 : tensor + %10090 = arith.index_cast %dim_3529 : index to i64 + %10091 = arith.maxsi %10089, %10090 : i64 + %10092 = arith.index_cast %10091 : i64 to index + %from_elements_3530 = tensor.from_elements %10092, %c4096 : tensor<2xindex> + %10093 = stablehlo.dynamic_broadcast_in_dim %10048, %from_elements_3530, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_3531 = tensor.dim %10093, %c0 : tensor + %10094 = arith.index_cast %dim_3531 : index to i64 + %from_elements_3532 = tensor.from_elements %10094, %c4096_i64 : tensor<2xi64> + %10095 = stablehlo.real_dynamic_slice %10088, %c_22, %from_elements_3532, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_3533 = tensor.from_elements %10094, %c4096_i64, %c1_i64 : tensor<3xi64> + %10096 = stablehlo.dynamic_reshape %10093, %from_elements_3533 : (tensor, tensor<3xi64>) -> tensor + %10097 = stablehlo.dynamic_iota %from_elements_3533, dim = 1 : (tensor<3xi64>) -> tensor + %10098 = stablehlo.concatenate %10096, %10097, dim = 2 : (tensor, tensor) -> tensor + %10099 = "stablehlo.scatter"(%10036, %10098, %10095) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %10100 = stablehlo.slice %9971 [2:3, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %10101 = stablehlo.reshape %10100 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %10102 = stablehlo.custom_call @byteir.non_zero(%10101) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_3534 = tensor.dim %10102, %c0 : tensor + %10103 = arith.index_cast %dim_3534 : index to i64 + %from_elements_3535 = tensor.from_elements %10103, %c1_i64 : tensor<2xi64> + %10104 = stablehlo.real_dynamic_slice %10102, %c_22, %from_elements_3535, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> 
tensor + %dim_3536 = tensor.dim %10104, %c0 : tensor + %10105 = arith.index_cast %dim_3536 : index to i64 + %from_elements_3537 = tensor.from_elements %10105 : tensor<1xi64> + %10106 = stablehlo.dynamic_reshape %10104, %from_elements_3537 : (tensor, tensor<1xi64>) -> tensor + %from_elements_3538 = tensor.from_elements %10103, %c2_i64 : tensor<2xi64> + %10107 = stablehlo.real_dynamic_slice %10102, %c_24, %from_elements_3538, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3539 = tensor.dim %10107, %c0 : tensor + %10108 = arith.index_cast %dim_3539 : index to i64 + %from_elements_3540 = tensor.from_elements %10108 : tensor<1xi64> + %10109 = stablehlo.dynamic_reshape %10107, %from_elements_3540 : (tensor, tensor<1xi64>) -> tensor + %dim_3541 = tensor.dim %10109, %c0 : tensor + %10110 = arith.index_cast %dim_3541 : index to i64 + %from_elements_3542 = tensor.from_elements %10110, %c1_i64 : tensor<2xi64> + %10111 = stablehlo.dynamic_reshape %10109, %from_elements_3542 : (tensor, tensor<2xi64>) -> tensor + %dim_3543 = tensor.dim %10111, %c0 : tensor + %10112 = arith.index_cast %dim_3543 : index to i64 + %from_elements_3544 = tensor.from_elements %c1_i64, %10112, %c4096_i64 : tensor<3xi64> + %10113 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_3544, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3545 = tensor.dim %10113, %c1 : tensor<1x?x4096xi64> + %10114 = arith.index_cast %dim_3545 : index to i64 + %from_elements_3546 = tensor.from_elements %c1_i64, %10114, %c4096_i64, %c1_i64 : tensor<4xi64> + %10115 = stablehlo.dynamic_reshape %10113, %from_elements_3546 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %10116 = stablehlo.dynamic_broadcast_in_dim %10111, %from_elements_3544, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3547 = tensor.dim %10116, %c1 : tensor<1x?x4096xi64> + %10117 = arith.index_cast %dim_3547 : index to i64 + %from_elements_3548 = tensor.from_elements %c1_i64, %10117, %c4096_i64, %c1_i64 : tensor<4xi64> + %10118 = stablehlo.dynamic_reshape %10116, %from_elements_3548 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %10119 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_3544, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3549 = tensor.dim %10119, %c1 : tensor<1x?x4096xi64> + %10120 = arith.index_cast %dim_3549 : index to i64 + %from_elements_3550 = tensor.from_elements %c1_i64, %10120, %c4096_i64, %c1_i64 : tensor<4xi64> + %10121 = stablehlo.dynamic_reshape %10119, %from_elements_3550 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %10122 = stablehlo.concatenate %10115, %10118, %10121, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %10123 = "stablehlo.gather"(%9982, %10122) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %10124 = shape.shape_of %10123 : tensor<1x?x4096xf32> -> tensor<3xindex> + %10125 = shape.num_elements %10124 : tensor<3xindex> -> index + %10126 = stablehlo.compute_reshape_shape %10125, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %10127 = stablehlo.dynamic_reshape %10123, %10126 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %10128 = stablehlo.dot %10127, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %10129 = stablehlo.logistic %10128 : tensor + %10130 = 
shape.shape_of %10129 : tensor -> tensor<2xindex> + %10131 = shape.shape_of %10128 : tensor -> tensor<2xindex> + %10132 = shape.cstr_broadcastable %10130, %10131 : tensor<2xindex>, tensor<2xindex> + %10133 = shape.assuming %10132 -> (tensor) { + %19688 = shape.broadcast %10130, %10131 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %10129, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %10128, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %10134 = shape.shape_of %10133 : tensor -> tensor<2xindex> + %10135 = shape.cstr_broadcastable %10134, %10131 : tensor<2xindex>, tensor<2xindex> + %10136 = shape.assuming %10135 -> (tensor) { + %19688 = shape.broadcast %10134, %10131 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %10133, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %10128, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %10137 = stablehlo.dot %10136, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_3551 = tensor.dim %10109, %c0 : tensor + %10138 = arith.index_cast %dim_3551 : index to i64 + %from_elements_3552 = tensor.from_elements %10138, %c1_i64 : tensor<2xi64> + %10139 = stablehlo.dynamic_reshape %10109, %from_elements_3552 : (tensor, tensor<2xi64>) -> tensor + %dim_3553 = tensor.dim %10106, %c0 : tensor + %10140 = arith.index_cast %dim_3553 : index to i64 + %from_elements_3554 = tensor.from_elements %10140, %c1_i64 : tensor<2xi64> + %10141 = stablehlo.dynamic_reshape %10106, %from_elements_3554 : (tensor, tensor<2xi64>) -> tensor + %10142 = stablehlo.concatenate %10139, %10141, dim = 1 : (tensor, tensor) -> tensor + %10143 = "stablehlo.gather"(%10011, %10142) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %10144 = shape.shape_of %10137 : tensor -> tensor<2xindex> + %10145 = shape.shape_of %10143 : tensor -> tensor<2xindex> + %10146 = shape.cstr_broadcastable %10144, %10145 : tensor<2xindex>, tensor<2xindex> + %10147 = shape.assuming %10146 -> (tensor) { + %19688 = shape.broadcast %10144, %10145 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %10137, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %10143, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %10148 = shape.shape_of %10147 : tensor -> tensor<2xindex> + %10149 = stablehlo.dynamic_broadcast_in_dim %10147, %10148, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %10150 = stablehlo.dynamic_broadcast_in_dim %213, %10148, dims = [] : (tensor, tensor<2xindex>) -> tensor + %10151 = stablehlo.multiply %10149, %10150 : tensor + %dim_3555 = tensor.dim %10111, %c0 : tensor + %10152 = arith.index_cast %dim_3555 : index to i64 + %dim_3556 = tensor.dim %10147, %c0 : tensor + %10153 = arith.index_cast %dim_3556 : index to i64 + %10154 = arith.maxsi %10152, %10153 : i64 + %10155 = arith.index_cast %10154 : i64 to index + %from_elements_3557 = tensor.from_elements %10155, %c4096 : 
tensor<2xindex> + %10156 = stablehlo.dynamic_broadcast_in_dim %10111, %from_elements_3557, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_3558 = tensor.dim %10156, %c0 : tensor + %10157 = arith.index_cast %dim_3558 : index to i64 + %from_elements_3559 = tensor.from_elements %10157, %c4096_i64 : tensor<2xi64> + %10158 = stablehlo.real_dynamic_slice %10151, %c_22, %from_elements_3559, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_3560 = tensor.from_elements %10157, %c4096_i64, %c1_i64 : tensor<3xi64> + %10159 = stablehlo.dynamic_reshape %10156, %from_elements_3560 : (tensor, tensor<3xi64>) -> tensor + %10160 = stablehlo.dynamic_iota %from_elements_3560, dim = 1 : (tensor<3xi64>) -> tensor + %10161 = stablehlo.concatenate %10159, %10160, dim = 2 : (tensor, tensor) -> tensor + %10162 = "stablehlo.scatter"(%10099, %10161, %10158) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %10163 = stablehlo.slice %9971 [3:4, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %10164 = stablehlo.reshape %10163 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %10165 = stablehlo.custom_call @byteir.non_zero(%10164) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_3561 = tensor.dim %10165, %c0 : tensor + %10166 = arith.index_cast %dim_3561 : index to i64 + %from_elements_3562 = tensor.from_elements %10166, %c1_i64 : tensor<2xi64> + %10167 = stablehlo.real_dynamic_slice %10165, %c_22, %from_elements_3562, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3563 = tensor.dim %10167, %c0 : tensor + %10168 = arith.index_cast %dim_3563 : index to i64 + %from_elements_3564 = tensor.from_elements %10168 : tensor<1xi64> + %10169 = stablehlo.dynamic_reshape %10167, %from_elements_3564 : (tensor, tensor<1xi64>) -> tensor + %from_elements_3565 = tensor.from_elements %10166, %c2_i64 : tensor<2xi64> + %10170 = stablehlo.real_dynamic_slice %10165, %c_24, %from_elements_3565, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3566 = tensor.dim %10170, %c0 : tensor + %10171 = arith.index_cast %dim_3566 : index to i64 + %from_elements_3567 = tensor.from_elements %10171 : tensor<1xi64> + %10172 = stablehlo.dynamic_reshape %10170, %from_elements_3567 : (tensor, tensor<1xi64>) -> tensor + %dim_3568 = tensor.dim %10172, %c0 : tensor + %10173 = arith.index_cast %dim_3568 : index to i64 + %from_elements_3569 = tensor.from_elements %10173, %c1_i64 : tensor<2xi64> + %10174 = stablehlo.dynamic_reshape %10172, %from_elements_3569 : (tensor, tensor<2xi64>) -> tensor + %dim_3570 = tensor.dim %10174, %c0 : tensor + %10175 = arith.index_cast %dim_3570 : index to i64 + %from_elements_3571 = tensor.from_elements %c1_i64, %10175, %c4096_i64 : tensor<3xi64> + %10176 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_3571, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3572 = tensor.dim %10176, %c1 : tensor<1x?x4096xi64> + %10177 = arith.index_cast %dim_3572 : index to i64 + %from_elements_3573 = tensor.from_elements %c1_i64, %10177, %c4096_i64, %c1_i64 : tensor<4xi64> + %10178 = stablehlo.dynamic_reshape %10176, %from_elements_3573 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %10179 = 
stablehlo.dynamic_broadcast_in_dim %10174, %from_elements_3571, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3574 = tensor.dim %10179, %c1 : tensor<1x?x4096xi64> + %10180 = arith.index_cast %dim_3574 : index to i64 + %from_elements_3575 = tensor.from_elements %c1_i64, %10180, %c4096_i64, %c1_i64 : tensor<4xi64> + %10181 = stablehlo.dynamic_reshape %10179, %from_elements_3575 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %10182 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_3571, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3576 = tensor.dim %10182, %c1 : tensor<1x?x4096xi64> + %10183 = arith.index_cast %dim_3576 : index to i64 + %from_elements_3577 = tensor.from_elements %c1_i64, %10183, %c4096_i64, %c1_i64 : tensor<4xi64> + %10184 = stablehlo.dynamic_reshape %10182, %from_elements_3577 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %10185 = stablehlo.concatenate %10178, %10181, %10184, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %10186 = "stablehlo.gather"(%9982, %10185) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %10187 = shape.shape_of %10186 : tensor<1x?x4096xf32> -> tensor<3xindex> + %10188 = shape.num_elements %10187 : tensor<3xindex> -> index + %10189 = stablehlo.compute_reshape_shape %10188, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %10190 = stablehlo.dynamic_reshape %10186, %10189 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %10191 = stablehlo.dot %10190, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %10192 = stablehlo.logistic %10191 : tensor + %10193 = shape.shape_of %10192 : tensor -> tensor<2xindex> + %10194 = shape.shape_of %10191 : tensor -> tensor<2xindex> + %10195 = shape.cstr_broadcastable %10193, %10194 : tensor<2xindex>, tensor<2xindex> + %10196 = shape.assuming %10195 -> (tensor) { + %19688 = shape.broadcast %10193, %10194 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %10192, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %10191, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %10197 = shape.shape_of %10196 : tensor -> tensor<2xindex> + %10198 = shape.cstr_broadcastable %10197, %10194 : tensor<2xindex>, tensor<2xindex> + %10199 = shape.assuming %10198 -> (tensor) { + %19688 = shape.broadcast %10197, %10194 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %10196, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %10191, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %10200 = stablehlo.dot %10199, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_3578 = tensor.dim %10172, %c0 : tensor + %10201 = arith.index_cast %dim_3578 : index to i64 + %from_elements_3579 = tensor.from_elements %10201, %c1_i64 : tensor<2xi64> + %10202 = stablehlo.dynamic_reshape %10172, %from_elements_3579 : (tensor, tensor<2xi64>) -> tensor + %dim_3580 = tensor.dim %10169, %c0 : tensor + %10203 = arith.index_cast %dim_3580 : index to i64 + 
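// gather each routed token's normalized top-2 routing weight (from %10011) and scale this expert's output with it before the scatter-add into the shared 3x4096 accumulator +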
%from_elements_3581 = tensor.from_elements %10203, %c1_i64 : tensor<2xi64> + %10204 = stablehlo.dynamic_reshape %10169, %from_elements_3581 : (tensor, tensor<2xi64>) -> tensor + %10205 = stablehlo.concatenate %10202, %10204, dim = 1 : (tensor, tensor) -> tensor + %10206 = "stablehlo.gather"(%10011, %10205) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %10207 = shape.shape_of %10200 : tensor -> tensor<2xindex> + %10208 = shape.shape_of %10206 : tensor -> tensor<2xindex> + %10209 = shape.cstr_broadcastable %10207, %10208 : tensor<2xindex>, tensor<2xindex> + %10210 = shape.assuming %10209 -> (tensor) { + %19688 = shape.broadcast %10207, %10208 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %10200, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %10206, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %10211 = shape.shape_of %10210 : tensor -> tensor<2xindex> + %10212 = stablehlo.dynamic_broadcast_in_dim %10210, %10211, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %10213 = stablehlo.dynamic_broadcast_in_dim %213, %10211, dims = [] : (tensor, tensor<2xindex>) -> tensor + %10214 = stablehlo.multiply %10212, %10213 : tensor + %dim_3582 = tensor.dim %10174, %c0 : tensor + %10215 = arith.index_cast %dim_3582 : index to i64 + %dim_3583 = tensor.dim %10210, %c0 : tensor + %10216 = arith.index_cast %dim_3583 : index to i64 + %10217 = arith.maxsi %10215, %10216 : i64 + %10218 = arith.index_cast %10217 : i64 to index + %from_elements_3584 = tensor.from_elements %10218, %c4096 : tensor<2xindex> + %10219 = stablehlo.dynamic_broadcast_in_dim %10174, %from_elements_3584, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_3585 = tensor.dim %10219, %c0 : tensor + %10220 = arith.index_cast %dim_3585 : index to i64 + %from_elements_3586 = tensor.from_elements %10220, %c4096_i64 : tensor<2xi64> + %10221 = stablehlo.real_dynamic_slice %10214, %c_22, %from_elements_3586, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_3587 = tensor.from_elements %10220, %c4096_i64, %c1_i64 : tensor<3xi64> + %10222 = stablehlo.dynamic_reshape %10219, %from_elements_3587 : (tensor, tensor<3xi64>) -> tensor + %10223 = stablehlo.dynamic_iota %from_elements_3587, dim = 1 : (tensor<3xi64>) -> tensor + %10224 = stablehlo.concatenate %10222, %10223, dim = 2 : (tensor, tensor) -> tensor + %10225 = "stablehlo.scatter"(%10162, %10224, %10221) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %10226 = stablehlo.slice %9971 [4:5, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %10227 = stablehlo.reshape %10226 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %10228 = stablehlo.custom_call @byteir.non_zero(%10227) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_3588 = tensor.dim %10228, %c0 : tensor + %10229 = arith.index_cast %dim_3588 : index to i64 + %from_elements_3589 = tensor.from_elements %10229, %c1_i64 : tensor<2xi64> + %10230 = stablehlo.real_dynamic_slice %10228, %c_22, %from_elements_3589, %c_21 : (tensor, 
tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3590 = tensor.dim %10230, %c0 : tensor + %10231 = arith.index_cast %dim_3590 : index to i64 + %from_elements_3591 = tensor.from_elements %10231 : tensor<1xi64> + %10232 = stablehlo.dynamic_reshape %10230, %from_elements_3591 : (tensor, tensor<1xi64>) -> tensor + %from_elements_3592 = tensor.from_elements %10229, %c2_i64 : tensor<2xi64> + %10233 = stablehlo.real_dynamic_slice %10228, %c_24, %from_elements_3592, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3593 = tensor.dim %10233, %c0 : tensor + %10234 = arith.index_cast %dim_3593 : index to i64 + %from_elements_3594 = tensor.from_elements %10234 : tensor<1xi64> + %10235 = stablehlo.dynamic_reshape %10233, %from_elements_3594 : (tensor, tensor<1xi64>) -> tensor + %dim_3595 = tensor.dim %10235, %c0 : tensor + %10236 = arith.index_cast %dim_3595 : index to i64 + %from_elements_3596 = tensor.from_elements %10236, %c1_i64 : tensor<2xi64> + %10237 = stablehlo.dynamic_reshape %10235, %from_elements_3596 : (tensor, tensor<2xi64>) -> tensor + %dim_3597 = tensor.dim %10237, %c0 : tensor + %10238 = arith.index_cast %dim_3597 : index to i64 + %from_elements_3598 = tensor.from_elements %c1_i64, %10238, %c4096_i64 : tensor<3xi64> + %10239 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_3598, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3599 = tensor.dim %10239, %c1 : tensor<1x?x4096xi64> + %10240 = arith.index_cast %dim_3599 : index to i64 + %from_elements_3600 = tensor.from_elements %c1_i64, %10240, %c4096_i64, %c1_i64 : tensor<4xi64> + %10241 = stablehlo.dynamic_reshape %10239, %from_elements_3600 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %10242 = stablehlo.dynamic_broadcast_in_dim %10237, %from_elements_3598, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3601 = tensor.dim %10242, %c1 : tensor<1x?x4096xi64> + %10243 = arith.index_cast %dim_3601 : index to i64 + %from_elements_3602 = tensor.from_elements %c1_i64, %10243, %c4096_i64, %c1_i64 : tensor<4xi64> + %10244 = stablehlo.dynamic_reshape %10242, %from_elements_3602 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %10245 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_3598, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3603 = tensor.dim %10245, %c1 : tensor<1x?x4096xi64> + %10246 = arith.index_cast %dim_3603 : index to i64 + %from_elements_3604 = tensor.from_elements %c1_i64, %10246, %c4096_i64, %c1_i64 : tensor<4xi64> + %10247 = stablehlo.dynamic_reshape %10245, %from_elements_3604 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %10248 = stablehlo.concatenate %10241, %10244, %10247, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %10249 = "stablehlo.gather"(%9982, %10248) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %10250 = shape.shape_of %10249 : tensor<1x?x4096xf32> -> tensor<3xindex> + %10251 = shape.num_elements %10250 : tensor<3xindex> -> index + %10252 = stablehlo.compute_reshape_shape %10251, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %10253 = stablehlo.dynamic_reshape %10249, %10252 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %10254 = stablehlo.dot %10253, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %10255 = 
stablehlo.logistic %10254 : tensor + %10256 = shape.shape_of %10255 : tensor -> tensor<2xindex> + %10257 = shape.shape_of %10254 : tensor -> tensor<2xindex> + %10258 = shape.cstr_broadcastable %10256, %10257 : tensor<2xindex>, tensor<2xindex> + %10259 = shape.assuming %10258 -> (tensor) { + %19688 = shape.broadcast %10256, %10257 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %10255, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %10254, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %10260 = shape.shape_of %10259 : tensor -> tensor<2xindex> + %10261 = shape.cstr_broadcastable %10260, %10257 : tensor<2xindex>, tensor<2xindex> + %10262 = shape.assuming %10261 -> (tensor) { + %19688 = shape.broadcast %10260, %10257 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %10259, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %10254, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %10263 = stablehlo.dot %10262, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_3605 = tensor.dim %10235, %c0 : tensor + %10264 = arith.index_cast %dim_3605 : index to i64 + %from_elements_3606 = tensor.from_elements %10264, %c1_i64 : tensor<2xi64> + %10265 = stablehlo.dynamic_reshape %10235, %from_elements_3606 : (tensor, tensor<2xi64>) -> tensor + %dim_3607 = tensor.dim %10232, %c0 : tensor + %10266 = arith.index_cast %dim_3607 : index to i64 + %from_elements_3608 = tensor.from_elements %10266, %c1_i64 : tensor<2xi64> + %10267 = stablehlo.dynamic_reshape %10232, %from_elements_3608 : (tensor, tensor<2xi64>) -> tensor + %10268 = stablehlo.concatenate %10265, %10267, dim = 1 : (tensor, tensor) -> tensor + %10269 = "stablehlo.gather"(%10011, %10268) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %10270 = shape.shape_of %10263 : tensor -> tensor<2xindex> + %10271 = shape.shape_of %10269 : tensor -> tensor<2xindex> + %10272 = shape.cstr_broadcastable %10270, %10271 : tensor<2xindex>, tensor<2xindex> + %10273 = shape.assuming %10272 -> (tensor) { + %19688 = shape.broadcast %10270, %10271 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %10263, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %10269, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %10274 = shape.shape_of %10273 : tensor -> tensor<2xindex> + %10275 = stablehlo.dynamic_broadcast_in_dim %10273, %10274, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %10276 = stablehlo.dynamic_broadcast_in_dim %213, %10274, dims = [] : (tensor, tensor<2xindex>) -> tensor + %10277 = stablehlo.multiply %10275, %10276 : tensor + %dim_3609 = tensor.dim %10237, %c0 : tensor + %10278 = arith.index_cast %dim_3609 : index to i64 + %dim_3610 = tensor.dim %10273, %c0 : tensor + %10279 = arith.index_cast %dim_3610 : index to i64 + %10280 = arith.maxsi %10278, %10279 : i64 + %10281 = arith.index_cast %10280 : i64 to index + %from_elements_3611 = 
tensor.from_elements %10281, %c4096 : tensor<2xindex> + %10282 = stablehlo.dynamic_broadcast_in_dim %10237, %from_elements_3611, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_3612 = tensor.dim %10282, %c0 : tensor + %10283 = arith.index_cast %dim_3612 : index to i64 + %from_elements_3613 = tensor.from_elements %10283, %c4096_i64 : tensor<2xi64> + %10284 = stablehlo.real_dynamic_slice %10277, %c_22, %from_elements_3613, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_3614 = tensor.from_elements %10283, %c4096_i64, %c1_i64 : tensor<3xi64> + %10285 = stablehlo.dynamic_reshape %10282, %from_elements_3614 : (tensor, tensor<3xi64>) -> tensor + %10286 = stablehlo.dynamic_iota %from_elements_3614, dim = 1 : (tensor<3xi64>) -> tensor + %10287 = stablehlo.concatenate %10285, %10286, dim = 2 : (tensor, tensor) -> tensor + %10288 = "stablehlo.scatter"(%10225, %10287, %10284) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %10289 = stablehlo.slice %9971 [5:6, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %10290 = stablehlo.reshape %10289 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %10291 = stablehlo.custom_call @byteir.non_zero(%10290) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_3615 = tensor.dim %10291, %c0 : tensor + %10292 = arith.index_cast %dim_3615 : index to i64 + %from_elements_3616 = tensor.from_elements %10292, %c1_i64 : tensor<2xi64> + %10293 = stablehlo.real_dynamic_slice %10291, %c_22, %from_elements_3616, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3617 = tensor.dim %10293, %c0 : tensor + %10294 = arith.index_cast %dim_3617 : index to i64 + %from_elements_3618 = tensor.from_elements %10294 : tensor<1xi64> + %10295 = stablehlo.dynamic_reshape %10293, %from_elements_3618 : (tensor, tensor<1xi64>) -> tensor + %from_elements_3619 = tensor.from_elements %10292, %c2_i64 : tensor<2xi64> + %10296 = stablehlo.real_dynamic_slice %10291, %c_24, %from_elements_3619, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3620 = tensor.dim %10296, %c0 : tensor + %10297 = arith.index_cast %dim_3620 : index to i64 + %from_elements_3621 = tensor.from_elements %10297 : tensor<1xi64> + %10298 = stablehlo.dynamic_reshape %10296, %from_elements_3621 : (tensor, tensor<1xi64>) -> tensor + %dim_3622 = tensor.dim %10298, %c0 : tensor + %10299 = arith.index_cast %dim_3622 : index to i64 + %from_elements_3623 = tensor.from_elements %10299, %c1_i64 : tensor<2xi64> + %10300 = stablehlo.dynamic_reshape %10298, %from_elements_3623 : (tensor, tensor<2xi64>) -> tensor + %dim_3624 = tensor.dim %10300, %c0 : tensor + %10301 = arith.index_cast %dim_3624 : index to i64 + %from_elements_3625 = tensor.from_elements %c1_i64, %10301, %c4096_i64 : tensor<3xi64> + %10302 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_3625, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3626 = tensor.dim %10302, %c1 : tensor<1x?x4096xi64> + %10303 = arith.index_cast %dim_3626 : index to i64 + %from_elements_3627 = tensor.from_elements %c1_i64, %10303, %c4096_i64, %c1_i64 : tensor<4xi64> + %10304 = stablehlo.dynamic_reshape %10302, %from_elements_3627 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + 
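// assemble (batch, token, hidden) indices so stablehlo.gather pulls only the rows of the 1x3x4096 hidden states that byteir.non_zero marked as routed to this expert +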
%10305 = stablehlo.dynamic_broadcast_in_dim %10300, %from_elements_3625, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3628 = tensor.dim %10305, %c1 : tensor<1x?x4096xi64> + %10306 = arith.index_cast %dim_3628 : index to i64 + %from_elements_3629 = tensor.from_elements %c1_i64, %10306, %c4096_i64, %c1_i64 : tensor<4xi64> + %10307 = stablehlo.dynamic_reshape %10305, %from_elements_3629 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %10308 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_3625, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3630 = tensor.dim %10308, %c1 : tensor<1x?x4096xi64> + %10309 = arith.index_cast %dim_3630 : index to i64 + %from_elements_3631 = tensor.from_elements %c1_i64, %10309, %c4096_i64, %c1_i64 : tensor<4xi64> + %10310 = stablehlo.dynamic_reshape %10308, %from_elements_3631 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %10311 = stablehlo.concatenate %10304, %10307, %10310, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %10312 = "stablehlo.gather"(%9982, %10311) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %10313 = shape.shape_of %10312 : tensor<1x?x4096xf32> -> tensor<3xindex> + %10314 = shape.num_elements %10313 : tensor<3xindex> -> index + %10315 = stablehlo.compute_reshape_shape %10314, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %10316 = stablehlo.dynamic_reshape %10312, %10315 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %10317 = stablehlo.dot %10316, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %10318 = stablehlo.logistic %10317 : tensor + %10319 = shape.shape_of %10318 : tensor -> tensor<2xindex> + %10320 = shape.shape_of %10317 : tensor -> tensor<2xindex> + %10321 = shape.cstr_broadcastable %10319, %10320 : tensor<2xindex>, tensor<2xindex> + %10322 = shape.assuming %10321 -> (tensor) { + %19688 = shape.broadcast %10319, %10320 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %10318, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %10317, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %10323 = shape.shape_of %10322 : tensor -> tensor<2xindex> + %10324 = shape.cstr_broadcastable %10323, %10320 : tensor<2xindex>, tensor<2xindex> + %10325 = shape.assuming %10324 -> (tensor) { + %19688 = shape.broadcast %10323, %10320 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %10322, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %10317, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %10326 = stablehlo.dot %10325, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_3632 = tensor.dim %10298, %c0 : tensor + %10327 = arith.index_cast %dim_3632 : index to i64 + %from_elements_3633 = tensor.from_elements %10327, %c1_i64 : tensor<2xi64> + %10328 = stablehlo.dynamic_reshape %10298, %from_elements_3633 : (tensor, tensor<2xi64>) -> tensor + %dim_3634 = tensor.dim %10295, %c0 : tensor + %10329 = arith.index_cast %dim_3634 : index to 
i64 + %from_elements_3635 = tensor.from_elements %10329, %c1_i64 : tensor<2xi64> + %10330 = stablehlo.dynamic_reshape %10295, %from_elements_3635 : (tensor, tensor<2xi64>) -> tensor + %10331 = stablehlo.concatenate %10328, %10330, dim = 1 : (tensor, tensor) -> tensor + %10332 = "stablehlo.gather"(%10011, %10331) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %10333 = shape.shape_of %10326 : tensor -> tensor<2xindex> + %10334 = shape.shape_of %10332 : tensor -> tensor<2xindex> + %10335 = shape.cstr_broadcastable %10333, %10334 : tensor<2xindex>, tensor<2xindex> + %10336 = shape.assuming %10335 -> (tensor) { + %19688 = shape.broadcast %10333, %10334 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %10326, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %10332, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %10337 = shape.shape_of %10336 : tensor -> tensor<2xindex> + %10338 = stablehlo.dynamic_broadcast_in_dim %10336, %10337, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %10339 = stablehlo.dynamic_broadcast_in_dim %213, %10337, dims = [] : (tensor, tensor<2xindex>) -> tensor + %10340 = stablehlo.multiply %10338, %10339 : tensor + %dim_3636 = tensor.dim %10300, %c0 : tensor + %10341 = arith.index_cast %dim_3636 : index to i64 + %dim_3637 = tensor.dim %10336, %c0 : tensor + %10342 = arith.index_cast %dim_3637 : index to i64 + %10343 = arith.maxsi %10341, %10342 : i64 + %10344 = arith.index_cast %10343 : i64 to index + %from_elements_3638 = tensor.from_elements %10344, %c4096 : tensor<2xindex> + %10345 = stablehlo.dynamic_broadcast_in_dim %10300, %from_elements_3638, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_3639 = tensor.dim %10345, %c0 : tensor + %10346 = arith.index_cast %dim_3639 : index to i64 + %from_elements_3640 = tensor.from_elements %10346, %c4096_i64 : tensor<2xi64> + %10347 = stablehlo.real_dynamic_slice %10340, %c_22, %from_elements_3640, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_3641 = tensor.from_elements %10346, %c4096_i64, %c1_i64 : tensor<3xi64> + %10348 = stablehlo.dynamic_reshape %10345, %from_elements_3641 : (tensor, tensor<3xi64>) -> tensor + %10349 = stablehlo.dynamic_iota %from_elements_3641, dim = 1 : (tensor<3xi64>) -> tensor + %10350 = stablehlo.concatenate %10348, %10349, dim = 2 : (tensor, tensor) -> tensor + %10351 = "stablehlo.scatter"(%10288, %10350, %10347) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %10352 = stablehlo.slice %9971 [6:7, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %10353 = stablehlo.reshape %10352 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %10354 = stablehlo.custom_call @byteir.non_zero(%10353) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_3642 = tensor.dim %10354, %c0 : tensor + %10355 = arith.index_cast %dim_3642 : index to i64 + %from_elements_3643 = tensor.from_elements %10355, %c1_i64 : tensor<2xi64> + %10356 = stablehlo.real_dynamic_slice %10354, %c_22, %from_elements_3643, %c_21 : (tensor, 
tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3644 = tensor.dim %10356, %c0 : tensor + %10357 = arith.index_cast %dim_3644 : index to i64 + %from_elements_3645 = tensor.from_elements %10357 : tensor<1xi64> + %10358 = stablehlo.dynamic_reshape %10356, %from_elements_3645 : (tensor, tensor<1xi64>) -> tensor + %from_elements_3646 = tensor.from_elements %10355, %c2_i64 : tensor<2xi64> + %10359 = stablehlo.real_dynamic_slice %10354, %c_24, %from_elements_3646, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3647 = tensor.dim %10359, %c0 : tensor + %10360 = arith.index_cast %dim_3647 : index to i64 + %from_elements_3648 = tensor.from_elements %10360 : tensor<1xi64> + %10361 = stablehlo.dynamic_reshape %10359, %from_elements_3648 : (tensor, tensor<1xi64>) -> tensor + %dim_3649 = tensor.dim %10361, %c0 : tensor + %10362 = arith.index_cast %dim_3649 : index to i64 + %from_elements_3650 = tensor.from_elements %10362, %c1_i64 : tensor<2xi64> + %10363 = stablehlo.dynamic_reshape %10361, %from_elements_3650 : (tensor, tensor<2xi64>) -> tensor + %dim_3651 = tensor.dim %10363, %c0 : tensor + %10364 = arith.index_cast %dim_3651 : index to i64 + %from_elements_3652 = tensor.from_elements %c1_i64, %10364, %c4096_i64 : tensor<3xi64> + %10365 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_3652, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3653 = tensor.dim %10365, %c1 : tensor<1x?x4096xi64> + %10366 = arith.index_cast %dim_3653 : index to i64 + %from_elements_3654 = tensor.from_elements %c1_i64, %10366, %c4096_i64, %c1_i64 : tensor<4xi64> + %10367 = stablehlo.dynamic_reshape %10365, %from_elements_3654 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %10368 = stablehlo.dynamic_broadcast_in_dim %10363, %from_elements_3652, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3655 = tensor.dim %10368, %c1 : tensor<1x?x4096xi64> + %10369 = arith.index_cast %dim_3655 : index to i64 + %from_elements_3656 = tensor.from_elements %c1_i64, %10369, %c4096_i64, %c1_i64 : tensor<4xi64> + %10370 = stablehlo.dynamic_reshape %10368, %from_elements_3656 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %10371 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_3652, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3657 = tensor.dim %10371, %c1 : tensor<1x?x4096xi64> + %10372 = arith.index_cast %dim_3657 : index to i64 + %from_elements_3658 = tensor.from_elements %c1_i64, %10372, %c4096_i64, %c1_i64 : tensor<4xi64> + %10373 = stablehlo.dynamic_reshape %10371, %from_elements_3658 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %10374 = stablehlo.concatenate %10367, %10370, %10373, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %10375 = "stablehlo.gather"(%9982, %10374) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %10376 = shape.shape_of %10375 : tensor<1x?x4096xf32> -> tensor<3xindex> + %10377 = shape.num_elements %10376 : tensor<3xindex> -> index + %10378 = stablehlo.compute_reshape_shape %10377, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %10379 = stablehlo.dynamic_reshape %10375, %10378 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %10380 = stablehlo.dot %10379, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %10381 = 
stablehlo.logistic %10380 : tensor + %10382 = shape.shape_of %10381 : tensor -> tensor<2xindex> + %10383 = shape.shape_of %10380 : tensor -> tensor<2xindex> + %10384 = shape.cstr_broadcastable %10382, %10383 : tensor<2xindex>, tensor<2xindex> + %10385 = shape.assuming %10384 -> (tensor) { + %19688 = shape.broadcast %10382, %10383 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %10381, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %10380, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %10386 = shape.shape_of %10385 : tensor -> tensor<2xindex> + %10387 = shape.cstr_broadcastable %10386, %10383 : tensor<2xindex>, tensor<2xindex> + %10388 = shape.assuming %10387 -> (tensor) { + %19688 = shape.broadcast %10386, %10383 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %10385, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %10380, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %10389 = stablehlo.dot %10388, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_3659 = tensor.dim %10361, %c0 : tensor + %10390 = arith.index_cast %dim_3659 : index to i64 + %from_elements_3660 = tensor.from_elements %10390, %c1_i64 : tensor<2xi64> + %10391 = stablehlo.dynamic_reshape %10361, %from_elements_3660 : (tensor, tensor<2xi64>) -> tensor + %dim_3661 = tensor.dim %10358, %c0 : tensor + %10392 = arith.index_cast %dim_3661 : index to i64 + %from_elements_3662 = tensor.from_elements %10392, %c1_i64 : tensor<2xi64> + %10393 = stablehlo.dynamic_reshape %10358, %from_elements_3662 : (tensor, tensor<2xi64>) -> tensor + %10394 = stablehlo.concatenate %10391, %10393, dim = 1 : (tensor, tensor) -> tensor + %10395 = "stablehlo.gather"(%10011, %10394) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %10396 = shape.shape_of %10389 : tensor -> tensor<2xindex> + %10397 = shape.shape_of %10395 : tensor -> tensor<2xindex> + %10398 = shape.cstr_broadcastable %10396, %10397 : tensor<2xindex>, tensor<2xindex> + %10399 = shape.assuming %10398 -> (tensor) { + %19688 = shape.broadcast %10396, %10397 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %10389, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %10395, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %10400 = shape.shape_of %10399 : tensor -> tensor<2xindex> + %10401 = stablehlo.dynamic_broadcast_in_dim %10399, %10400, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %10402 = stablehlo.dynamic_broadcast_in_dim %213, %10400, dims = [] : (tensor, tensor<2xindex>) -> tensor + %10403 = stablehlo.multiply %10401, %10402 : tensor + %dim_3663 = tensor.dim %10363, %c0 : tensor + %10404 = arith.index_cast %dim_3663 : index to i64 + %dim_3664 = tensor.dim %10399, %c0 : tensor + %10405 = arith.index_cast %dim_3664 : index to i64 + %10406 = arith.maxsi %10404, %10405 : i64 + %10407 = arith.index_cast %10406 : i64 to index + %from_elements_3665 = 
tensor.from_elements %10407, %c4096 : tensor<2xindex> + %10408 = stablehlo.dynamic_broadcast_in_dim %10363, %from_elements_3665, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_3666 = tensor.dim %10408, %c0 : tensor + %10409 = arith.index_cast %dim_3666 : index to i64 + %from_elements_3667 = tensor.from_elements %10409, %c4096_i64 : tensor<2xi64> + %10410 = stablehlo.real_dynamic_slice %10403, %c_22, %from_elements_3667, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_3668 = tensor.from_elements %10409, %c4096_i64, %c1_i64 : tensor<3xi64> + %10411 = stablehlo.dynamic_reshape %10408, %from_elements_3668 : (tensor, tensor<3xi64>) -> tensor + %10412 = stablehlo.dynamic_iota %from_elements_3668, dim = 1 : (tensor<3xi64>) -> tensor + %10413 = stablehlo.concatenate %10411, %10412, dim = 2 : (tensor, tensor) -> tensor + %10414 = "stablehlo.scatter"(%10351, %10413, %10410) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %10415 = stablehlo.slice %9971 [7:8, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %10416 = stablehlo.reshape %10415 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %10417 = stablehlo.custom_call @byteir.non_zero(%10416) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_3669 = tensor.dim %10417, %c0 : tensor + %10418 = arith.index_cast %dim_3669 : index to i64 + %from_elements_3670 = tensor.from_elements %10418, %c1_i64 : tensor<2xi64> + %10419 = stablehlo.real_dynamic_slice %10417, %c_22, %from_elements_3670, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3671 = tensor.dim %10419, %c0 : tensor + %10420 = arith.index_cast %dim_3671 : index to i64 + %from_elements_3672 = tensor.from_elements %10420 : tensor<1xi64> + %10421 = stablehlo.dynamic_reshape %10419, %from_elements_3672 : (tensor, tensor<1xi64>) -> tensor + %from_elements_3673 = tensor.from_elements %10418, %c2_i64 : tensor<2xi64> + %10422 = stablehlo.real_dynamic_slice %10417, %c_24, %from_elements_3673, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3674 = tensor.dim %10422, %c0 : tensor + %10423 = arith.index_cast %dim_3674 : index to i64 + %from_elements_3675 = tensor.from_elements %10423 : tensor<1xi64> + %10424 = stablehlo.dynamic_reshape %10422, %from_elements_3675 : (tensor, tensor<1xi64>) -> tensor + %dim_3676 = tensor.dim %10424, %c0 : tensor + %10425 = arith.index_cast %dim_3676 : index to i64 + %from_elements_3677 = tensor.from_elements %10425, %c1_i64 : tensor<2xi64> + %10426 = stablehlo.dynamic_reshape %10424, %from_elements_3677 : (tensor, tensor<2xi64>) -> tensor + %dim_3678 = tensor.dim %10426, %c0 : tensor + %10427 = arith.index_cast %dim_3678 : index to i64 + %from_elements_3679 = tensor.from_elements %c1_i64, %10427, %c4096_i64 : tensor<3xi64> + %10428 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_3679, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3680 = tensor.dim %10428, %c1 : tensor<1x?x4096xi64> + %10429 = arith.index_cast %dim_3680 : index to i64 + %from_elements_3681 = tensor.from_elements %c1_i64, %10429, %c4096_i64, %c1_i64 : tensor<4xi64> + %10430 = stablehlo.dynamic_reshape %10428, %from_elements_3681 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + 
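// expert 8 of 8 repeats the same pattern (gather its tokens, run the 4096 -> 14336 -> 4096 expert MLP with SiLU gating, weight by the routing probability, scatter-add); the accumulated result %10477 is this layer's MoE output, added back to the residual stream at %10479 +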
%10431 = stablehlo.dynamic_broadcast_in_dim %10426, %from_elements_3679, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3682 = tensor.dim %10431, %c1 : tensor<1x?x4096xi64> + %10432 = arith.index_cast %dim_3682 : index to i64 + %from_elements_3683 = tensor.from_elements %c1_i64, %10432, %c4096_i64, %c1_i64 : tensor<4xi64> + %10433 = stablehlo.dynamic_reshape %10431, %from_elements_3683 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %10434 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_3679, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3684 = tensor.dim %10434, %c1 : tensor<1x?x4096xi64> + %10435 = arith.index_cast %dim_3684 : index to i64 + %from_elements_3685 = tensor.from_elements %c1_i64, %10435, %c4096_i64, %c1_i64 : tensor<4xi64> + %10436 = stablehlo.dynamic_reshape %10434, %from_elements_3685 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %10437 = stablehlo.concatenate %10430, %10433, %10436, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %10438 = "stablehlo.gather"(%9982, %10437) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %10439 = shape.shape_of %10438 : tensor<1x?x4096xf32> -> tensor<3xindex> + %10440 = shape.num_elements %10439 : tensor<3xindex> -> index + %10441 = stablehlo.compute_reshape_shape %10440, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %10442 = stablehlo.dynamic_reshape %10438, %10441 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %10443 = stablehlo.dot %10442, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %10444 = stablehlo.logistic %10443 : tensor + %10445 = shape.shape_of %10444 : tensor -> tensor<2xindex> + %10446 = shape.shape_of %10443 : tensor -> tensor<2xindex> + %10447 = shape.cstr_broadcastable %10445, %10446 : tensor<2xindex>, tensor<2xindex> + %10448 = shape.assuming %10447 -> (tensor) { + %19688 = shape.broadcast %10445, %10446 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %10444, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %10443, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %10449 = shape.shape_of %10448 : tensor -> tensor<2xindex> + %10450 = shape.cstr_broadcastable %10449, %10446 : tensor<2xindex>, tensor<2xindex> + %10451 = shape.assuming %10450 -> (tensor) { + %19688 = shape.broadcast %10449, %10446 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %10448, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %10443, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %10452 = stablehlo.dot %10451, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_3686 = tensor.dim %10424, %c0 : tensor + %10453 = arith.index_cast %dim_3686 : index to i64 + %from_elements_3687 = tensor.from_elements %10453, %c1_i64 : tensor<2xi64> + %10454 = stablehlo.dynamic_reshape %10424, %from_elements_3687 : (tensor, tensor<2xi64>) -> tensor + %dim_3688 = tensor.dim %10421, %c0 : tensor + %10455 = arith.index_cast %dim_3688 : index to 
i64 + %from_elements_3689 = tensor.from_elements %10455, %c1_i64 : tensor<2xi64> + %10456 = stablehlo.dynamic_reshape %10421, %from_elements_3689 : (tensor, tensor<2xi64>) -> tensor + %10457 = stablehlo.concatenate %10454, %10456, dim = 1 : (tensor, tensor) -> tensor + %10458 = "stablehlo.gather"(%10011, %10457) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %10459 = shape.shape_of %10452 : tensor -> tensor<2xindex> + %10460 = shape.shape_of %10458 : tensor -> tensor<2xindex> + %10461 = shape.cstr_broadcastable %10459, %10460 : tensor<2xindex>, tensor<2xindex> + %10462 = shape.assuming %10461 -> (tensor) { + %19688 = shape.broadcast %10459, %10460 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %10452, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %10458, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %10463 = shape.shape_of %10462 : tensor -> tensor<2xindex> + %10464 = stablehlo.dynamic_broadcast_in_dim %10462, %10463, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %10465 = stablehlo.dynamic_broadcast_in_dim %213, %10463, dims = [] : (tensor, tensor<2xindex>) -> tensor + %10466 = stablehlo.multiply %10464, %10465 : tensor + %dim_3690 = tensor.dim %10426, %c0 : tensor + %10467 = arith.index_cast %dim_3690 : index to i64 + %dim_3691 = tensor.dim %10462, %c0 : tensor + %10468 = arith.index_cast %dim_3691 : index to i64 + %10469 = arith.maxsi %10467, %10468 : i64 + %10470 = arith.index_cast %10469 : i64 to index + %from_elements_3692 = tensor.from_elements %10470, %c4096 : tensor<2xindex> + %10471 = stablehlo.dynamic_broadcast_in_dim %10426, %from_elements_3692, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_3693 = tensor.dim %10471, %c0 : tensor + %10472 = arith.index_cast %dim_3693 : index to i64 + %from_elements_3694 = tensor.from_elements %10472, %c4096_i64 : tensor<2xi64> + %10473 = stablehlo.real_dynamic_slice %10466, %c_22, %from_elements_3694, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_3695 = tensor.from_elements %10472, %c4096_i64, %c1_i64 : tensor<3xi64> + %10474 = stablehlo.dynamic_reshape %10471, %from_elements_3695 : (tensor, tensor<3xi64>) -> tensor + %10475 = stablehlo.dynamic_iota %from_elements_3695, dim = 1 : (tensor<3xi64>) -> tensor + %10476 = stablehlo.concatenate %10474, %10475, dim = 2 : (tensor, tensor) -> tensor + %10477 = "stablehlo.scatter"(%10414, %10476, %10473) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %10478 = stablehlo.reshape %10477 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %10479 = stablehlo.add %9944, %10478 : tensor<3x1x4096xf32> + %10480 = stablehlo.broadcast_in_dim %10479, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %10481 = stablehlo.power %10480, %15 : tensor<3x1x4096xf32> + %10482 = stablehlo.reduce(%10481 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %10483 = stablehlo.reshape %10482 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + 
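// RMSNorm before the next decoder layer's self-attention: mean of x^2 over the 4096 hidden features, rsqrt(mean + eps), then an elementwise scale by the norm weight +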
%10484 = stablehlo.broadcast_in_dim %10483, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %10485 = stablehlo.divide %10484, %21 : tensor<3x1x1xf32> + %10486 = stablehlo.broadcast_in_dim %10485, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %10487 = stablehlo.add %10486, %25 : tensor<3x1x1xf32> + %10488 = stablehlo.rsqrt %10487 : tensor<3x1x1xf32> + %10489 = stablehlo.broadcast_in_dim %10488, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %10490 = stablehlo.multiply %10480, %10489 : tensor<3x1x4096xf32> + %10491 = stablehlo.broadcast_in_dim %10490, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %10492 = stablehlo.multiply %10491, %31 : tensor<3x1x4096xf32> + %10493 = stablehlo.reshape %10492 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %10494 = stablehlo.dot %10493, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %10495 = stablehlo.reshape %10494 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %10496 = stablehlo.dot %10493, %37 : (tensor<3x4096xf32>, tensor<4096x1024xf32>) -> tensor<3x1024xf32> + %10497 = stablehlo.reshape %10496 : (tensor<3x1024xf32>) -> tensor<3x1x1024xf32> + %10498 = stablehlo.reshape %10495 : (tensor<3x1x4096xf32>) -> tensor<3x1x32x128xf32> + %10499 = stablehlo.transpose %10498, dims = [0, 2, 1, 3] : (tensor<3x1x32x128xf32>) -> tensor<3x32x1x128xf32> + %10500 = stablehlo.reshape %10497 : (tensor<3x1x1024xf32>) -> tensor<3x1x8x128xf32> + %10501 = stablehlo.transpose %10500, dims = [0, 2, 1, 3] : (tensor<3x1x8x128xf32>) -> tensor<3x8x1x128xf32> + %10502 = stablehlo.slice %arg34 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %10503 = stablehlo.slice %arg35 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %10504 = "stablehlo.gather"(%10502, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %10505 = stablehlo.reshape %10504 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %10506 = "stablehlo.gather"(%10503, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %10507 = stablehlo.reshape %10506 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %10508 = stablehlo.broadcast_in_dim %10499, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %10509 = stablehlo.broadcast_in_dim %10505, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %10510 = stablehlo.multiply %10508, %10509 : tensor<3x32x1x128xf32> + %10511 = stablehlo.slice %10499 [0:3, 0:32, 0:1, 0:64] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %10512 = stablehlo.slice %10499 [0:3, 0:32, 0:1, 64:128] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %10513 = stablehlo.negate %10512 : tensor<3x32x1x64xf32> + %10514 = stablehlo.concatenate %10513, %10511, dim = 3 : (tensor<3x32x1x64xf32>, tensor<3x32x1x64xf32>) -> tensor<3x32x1x128xf32> + %10515 = stablehlo.broadcast_in_dim %10514, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %10516 = stablehlo.broadcast_in_dim %10507, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %10517 = stablehlo.multiply %10515, %10516 : tensor<3x32x1x128xf32> + %10518 = stablehlo.add %10510, %10517 : tensor<3x32x1x128xf32> + %10519 = stablehlo.broadcast_in_dim %10501, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + 
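// rotary position embedding: cos/sin rows for the current position are gathered from the 131072x128 tables (%10502/%10503), each 128-wide head vector is split in half, rotated (negate + concat), and recombined; applied to the query above and to the key here before the KV-cache concat with the past keys/values +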
%10520 = stablehlo.broadcast_in_dim %10505, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %10521 = stablehlo.multiply %10519, %10520 : tensor<3x8x1x128xf32> + %10522 = stablehlo.slice %10501 [0:3, 0:8, 0:1, 0:64] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %10523 = stablehlo.slice %10501 [0:3, 0:8, 0:1, 64:128] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %10524 = stablehlo.negate %10523 : tensor<3x8x1x64xf32> + %10525 = stablehlo.concatenate %10524, %10522, dim = 3 : (tensor<3x8x1x64xf32>, tensor<3x8x1x64xf32>) -> tensor<3x8x1x128xf32> + %10526 = stablehlo.broadcast_in_dim %10525, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %10527 = stablehlo.broadcast_in_dim %10507, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %10528 = stablehlo.multiply %10526, %10527 : tensor<3x8x1x128xf32> + %10529 = stablehlo.add %10521, %10528 : tensor<3x8x1x128xf32> + %10530 = stablehlo.concatenate %arg99, %10529, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %10531 = stablehlo.concatenate %arg100, %10501, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %10532 = stablehlo.reshape %10530 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %10533 = stablehlo.broadcast_in_dim %10532, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %10534 = stablehlo.reshape %10533 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %10535 = stablehlo.reshape %10531 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %10536 = stablehlo.broadcast_in_dim %10535, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %10537 = stablehlo.reshape %10536 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %10538 = stablehlo.transpose %10534, dims = [0, 1, 3, 2] : (tensor<3x32x8x128xf32>) -> tensor<3x32x128x8xf32> + %10539 = stablehlo.reshape %10518 : (tensor<3x32x1x128xf32>) -> tensor<96x1x128xf32> + %10540 = stablehlo.reshape %10538 : (tensor<3x32x128x8xf32>) -> tensor<96x128x8xf32> + %10541 = stablehlo.broadcast_in_dim %10540, dims = [0, 1, 2] : (tensor<96x128x8xf32>) -> tensor<96x128x8xf32> + %10542 = stablehlo.dot_general %10539, %10541, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x128xf32>, tensor<96x128x8xf32>) -> tensor<96x1x8xf32> + %10543 = stablehlo.reshape %10542 : (tensor<96x1x8xf32>) -> tensor<3x32x1x8xf32> + %10544 = stablehlo.broadcast_in_dim %10543, dims = [0, 1, 2, 3] : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %10545 = stablehlo.divide %10544, %89 : tensor<3x32x1x8xf32> + %10546 = stablehlo.custom_call @byteir.softmax(%10545) {byteir_attrs = {axis = 3 : i64}} : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %10547 = stablehlo.reshape %10546 : (tensor<3x32x1x8xf32>) -> tensor<96x1x8xf32> + %10548 = stablehlo.reshape %10537 : (tensor<3x32x8x128xf32>) -> tensor<96x8x128xf32> + %10549 = stablehlo.broadcast_in_dim %10548, dims = [0, 1, 2] : (tensor<96x8x128xf32>) -> tensor<96x8x128xf32> + %10550 = stablehlo.dot_general %10547, %10549, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x8xf32>, tensor<96x8x128xf32>) -> tensor<96x1x128xf32> + %10551 = stablehlo.reshape %10550 : (tensor<96x1x128xf32>) -> tensor<3x32x1x128xf32> + %10552 = stablehlo.transpose %10551, dims = [0, 2, 1, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x1x32x128xf32> + %10553 = stablehlo.reshape %10552 : (tensor<3x1x32x128xf32>) -> tensor<3x1x4096xf32> + %10554 = 
stablehlo.reshape %10553 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %10555 = stablehlo.dot %10554, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %10556 = stablehlo.reshape %10555 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %10557 = stablehlo.add %10479, %10556 : tensor<3x1x4096xf32> + %10558 = stablehlo.broadcast_in_dim %10557, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %10559 = stablehlo.power %10558, %15 : tensor<3x1x4096xf32> + %10560 = stablehlo.reduce(%10559 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %10561 = stablehlo.reshape %10560 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %10562 = stablehlo.broadcast_in_dim %10561, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %10563 = stablehlo.divide %10562, %21 : tensor<3x1x1xf32> + %10564 = stablehlo.broadcast_in_dim %10563, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %10565 = stablehlo.add %10564, %25 : tensor<3x1x1xf32> + %10566 = stablehlo.rsqrt %10565 : tensor<3x1x1xf32> + %10567 = stablehlo.broadcast_in_dim %10566, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %10568 = stablehlo.multiply %10558, %10567 : tensor<3x1x4096xf32> + %10569 = stablehlo.broadcast_in_dim %10568, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %10570 = stablehlo.multiply %10569, %31 : tensor<3x1x4096xf32> + %10571 = stablehlo.reshape %10570 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %10572 = stablehlo.dot %10571, %117 : (tensor<3x4096xf32>, tensor<4096x8xf32>) -> tensor<3x8xf32> + %10573 = stablehlo.custom_call @byteir.softmax(%10572) {byteir_attrs = {axis = 1 : i64}} : (tensor<3x8xf32>) -> tensor<3x8xf32> + %10574:2 = stablehlo.custom_call @byteir.top_k(%10573) {byteir_attrs = {axis = [1], k = 2 : i64, sorted = true}} : (tensor<3x8xf32>) -> (tensor<3x2xf32>, tensor<3x2xi64>) + %10575 = stablehlo.reduce(%10574#0 init: %cst_20) applies stablehlo.add across dimensions = [1] : (tensor<3x2xf32>, tensor) -> tensor<3xf32> + %10576 = stablehlo.reshape %10575 : (tensor<3xf32>) -> tensor<3x1xf32> + %10577 = stablehlo.broadcast_in_dim %10574#0, dims = [0, 1] : (tensor<3x2xf32>) -> tensor<3x2xf32> + %10578 = stablehlo.broadcast_in_dim %10576, dims = [0, 1] : (tensor<3x1xf32>) -> tensor<3x2xf32> + %10579 = stablehlo.divide %10577, %10578 : tensor<3x2xf32> + %10580 = stablehlo.reshape %10574#1 : (tensor<3x2xi64>) -> tensor<3x2x1xi64> + %10581 = stablehlo.broadcast_in_dim %10580, dims = [0, 1, 2] : (tensor<3x2x1xi64>) -> tensor<3x2x8xi64> + %10582 = stablehlo.compare EQ, %10581, %137, SIGNED : (tensor<3x2x8xi64>, tensor<3x2x8xi64>) -> tensor<3x2x8xi1> + %10583 = stablehlo.convert %10582 : (tensor<3x2x8xi1>) -> tensor<3x2x8xi64> + %10584 = stablehlo.transpose %10583, dims = [2, 1, 0] : (tensor<3x2x8xi64>) -> tensor<8x2x3xi64> + %10585 = stablehlo.slice %10584 [0:1, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %10586 = stablehlo.reshape %10585 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %10587 = stablehlo.custom_call @byteir.non_zero(%10586) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_3696 = tensor.dim %10587, %c0 : tensor + %10588 = arith.index_cast %dim_3696 : index to i64 + %from_elements_3697 = tensor.from_elements %10588, %c1_i64 : tensor<2xi64> + %10589 = stablehlo.real_dynamic_slice %10587, %c_22, %from_elements_3697, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3698 = tensor.dim %10589, %c0 : tensor + 
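// top-2 MoE routing for this layer: the 4096 -> 8 gate produces per-token logits, byteir.softmax and byteir.top_k (k = 2) pick two experts, the two weights are renormalized to sum to 1, and the resulting 8x2x3 one-hot mask lets each expert find its assigned tokens via byteir.non_zero +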
%10590 = arith.index_cast %dim_3698 : index to i64 + %from_elements_3699 = tensor.from_elements %10590 : tensor<1xi64> + %10591 = stablehlo.dynamic_reshape %10589, %from_elements_3699 : (tensor, tensor<1xi64>) -> tensor + %from_elements_3700 = tensor.from_elements %10588, %c2_i64 : tensor<2xi64> + %10592 = stablehlo.real_dynamic_slice %10587, %c_24, %from_elements_3700, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3701 = tensor.dim %10592, %c0 : tensor + %10593 = arith.index_cast %dim_3701 : index to i64 + %from_elements_3702 = tensor.from_elements %10593 : tensor<1xi64> + %10594 = stablehlo.dynamic_reshape %10592, %from_elements_3702 : (tensor, tensor<1xi64>) -> tensor + %10595 = stablehlo.reshape %10571 : (tensor<3x4096xf32>) -> tensor<1x3x4096xf32> + %dim_3703 = tensor.dim %10594, %c0 : tensor + %10596 = arith.index_cast %dim_3703 : index to i64 + %from_elements_3704 = tensor.from_elements %10596, %c1_i64 : tensor<2xi64> + %10597 = stablehlo.dynamic_reshape %10594, %from_elements_3704 : (tensor, tensor<2xi64>) -> tensor + %dim_3705 = tensor.dim %10597, %c0 : tensor + %10598 = arith.index_cast %dim_3705 : index to i64 + %from_elements_3706 = tensor.from_elements %c1_i64, %10598, %c4096_i64 : tensor<3xi64> + %10599 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_3706, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3707 = tensor.dim %10599, %c1 : tensor<1x?x4096xi64> + %10600 = arith.index_cast %dim_3707 : index to i64 + %from_elements_3708 = tensor.from_elements %c1_i64, %10600, %c4096_i64, %c1_i64 : tensor<4xi64> + %10601 = stablehlo.dynamic_reshape %10599, %from_elements_3708 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %10602 = stablehlo.dynamic_broadcast_in_dim %10597, %from_elements_3706, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3709 = tensor.dim %10602, %c1 : tensor<1x?x4096xi64> + %10603 = arith.index_cast %dim_3709 : index to i64 + %from_elements_3710 = tensor.from_elements %c1_i64, %10603, %c4096_i64, %c1_i64 : tensor<4xi64> + %10604 = stablehlo.dynamic_reshape %10602, %from_elements_3710 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %10605 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_3706, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3711 = tensor.dim %10605, %c1 : tensor<1x?x4096xi64> + %10606 = arith.index_cast %dim_3711 : index to i64 + %from_elements_3712 = tensor.from_elements %c1_i64, %10606, %c4096_i64, %c1_i64 : tensor<4xi64> + %10607 = stablehlo.dynamic_reshape %10605, %from_elements_3712 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %10608 = stablehlo.concatenate %10601, %10604, %10607, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %10609 = "stablehlo.gather"(%10595, %10608) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %10610 = shape.shape_of %10609 : tensor<1x?x4096xf32> -> tensor<3xindex> + %10611 = shape.num_elements %10610 : tensor<3xindex> -> index + %10612 = stablehlo.compute_reshape_shape %10611, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %10613 = stablehlo.dynamic_reshape %10609, %10612 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %10614 = stablehlo.dot %10613, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %10615 = stablehlo.logistic 
%10614 : tensor + %10616 = shape.shape_of %10615 : tensor -> tensor<2xindex> + %10617 = shape.shape_of %10614 : tensor -> tensor<2xindex> + %10618 = shape.cstr_broadcastable %10616, %10617 : tensor<2xindex>, tensor<2xindex> + %10619 = shape.assuming %10618 -> (tensor) { + %19688 = shape.broadcast %10616, %10617 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %10615, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %10614, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %10620 = shape.shape_of %10619 : tensor -> tensor<2xindex> + %10621 = shape.cstr_broadcastable %10620, %10617 : tensor<2xindex>, tensor<2xindex> + %10622 = shape.assuming %10621 -> (tensor) { + %19688 = shape.broadcast %10620, %10617 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %10619, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %10614, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %10623 = stablehlo.dot %10622, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %10624 = stablehlo.reshape %10579 : (tensor<3x2xf32>) -> tensor<3x2x1xf32> + %dim_3713 = tensor.dim %10594, %c0 : tensor + %10625 = arith.index_cast %dim_3713 : index to i64 + %from_elements_3714 = tensor.from_elements %10625, %c1_i64 : tensor<2xi64> + %10626 = stablehlo.dynamic_reshape %10594, %from_elements_3714 : (tensor, tensor<2xi64>) -> tensor + %dim_3715 = tensor.dim %10591, %c0 : tensor + %10627 = arith.index_cast %dim_3715 : index to i64 + %from_elements_3716 = tensor.from_elements %10627, %c1_i64 : tensor<2xi64> + %10628 = stablehlo.dynamic_reshape %10591, %from_elements_3716 : (tensor, tensor<2xi64>) -> tensor + %10629 = stablehlo.concatenate %10626, %10628, dim = 1 : (tensor, tensor) -> tensor + %10630 = "stablehlo.gather"(%10624, %10629) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %10631 = shape.shape_of %10623 : tensor -> tensor<2xindex> + %10632 = shape.shape_of %10630 : tensor -> tensor<2xindex> + %10633 = shape.cstr_broadcastable %10631, %10632 : tensor<2xindex>, tensor<2xindex> + %10634 = shape.assuming %10633 -> (tensor) { + %19688 = shape.broadcast %10631, %10632 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %10623, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %10630, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %10635 = shape.shape_of %10634 : tensor -> tensor<2xindex> + %10636 = stablehlo.dynamic_broadcast_in_dim %10634, %10635, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %10637 = stablehlo.dynamic_broadcast_in_dim %213, %10635, dims = [] : (tensor, tensor<2xindex>) -> tensor + %10638 = stablehlo.multiply %10636, %10637 : tensor + %dim_3717 = tensor.dim %10597, %c0 : tensor + %10639 = arith.index_cast %dim_3717 : index to i64 + %dim_3718 = tensor.dim %10634, %c0 : tensor + %10640 = arith.index_cast %dim_3718 : index to i64 + %10641 = arith.maxsi %10639, %10640 : i64 + %10642 = 
arith.index_cast %10641 : i64 to index + %from_elements_3719 = tensor.from_elements %10642, %c4096 : tensor<2xindex> + %10643 = stablehlo.dynamic_broadcast_in_dim %10597, %from_elements_3719, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_3720 = tensor.dim %10643, %c0 : tensor + %10644 = arith.index_cast %dim_3720 : index to i64 + %from_elements_3721 = tensor.from_elements %10644, %c4096_i64 : tensor<2xi64> + %10645 = stablehlo.real_dynamic_slice %10638, %c_22, %from_elements_3721, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_3722 = tensor.from_elements %10644, %c4096_i64, %c1_i64 : tensor<3xi64> + %10646 = stablehlo.dynamic_reshape %10643, %from_elements_3722 : (tensor, tensor<3xi64>) -> tensor + %10647 = stablehlo.dynamic_iota %from_elements_3722, dim = 1 : (tensor<3xi64>) -> tensor + %10648 = stablehlo.concatenate %10646, %10647, dim = 2 : (tensor, tensor) -> tensor + %10649 = "stablehlo.scatter"(%cst_2, %10648, %10645) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %10650 = stablehlo.slice %10584 [1:2, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %10651 = stablehlo.reshape %10650 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %10652 = stablehlo.custom_call @byteir.non_zero(%10651) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_3723 = tensor.dim %10652, %c0 : tensor + %10653 = arith.index_cast %dim_3723 : index to i64 + %from_elements_3724 = tensor.from_elements %10653, %c1_i64 : tensor<2xi64> + %10654 = stablehlo.real_dynamic_slice %10652, %c_22, %from_elements_3724, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3725 = tensor.dim %10654, %c0 : tensor + %10655 = arith.index_cast %dim_3725 : index to i64 + %from_elements_3726 = tensor.from_elements %10655 : tensor<1xi64> + %10656 = stablehlo.dynamic_reshape %10654, %from_elements_3726 : (tensor, tensor<1xi64>) -> tensor + %from_elements_3727 = tensor.from_elements %10653, %c2_i64 : tensor<2xi64> + %10657 = stablehlo.real_dynamic_slice %10652, %c_24, %from_elements_3727, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3728 = tensor.dim %10657, %c0 : tensor + %10658 = arith.index_cast %dim_3728 : index to i64 + %from_elements_3729 = tensor.from_elements %10658 : tensor<1xi64> + %10659 = stablehlo.dynamic_reshape %10657, %from_elements_3729 : (tensor, tensor<1xi64>) -> tensor + %dim_3730 = tensor.dim %10659, %c0 : tensor + %10660 = arith.index_cast %dim_3730 : index to i64 + %from_elements_3731 = tensor.from_elements %10660, %c1_i64 : tensor<2xi64> + %10661 = stablehlo.dynamic_reshape %10659, %from_elements_3731 : (tensor, tensor<2xi64>) -> tensor + %dim_3732 = tensor.dim %10661, %c0 : tensor + %10662 = arith.index_cast %dim_3732 : index to i64 + %from_elements_3733 = tensor.from_elements %c1_i64, %10662, %c4096_i64 : tensor<3xi64> + %10663 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_3733, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3734 = tensor.dim %10663, %c1 : tensor<1x?x4096xi64> + %10664 = arith.index_cast %dim_3734 : index to i64 + %from_elements_3735 = tensor.from_elements %c1_i64, %10664, %c4096_i64, %c1_i64 : tensor<4xi64> + %10665 = stablehlo.dynamic_reshape %10663, %from_elements_3735 : 
(tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %10666 = stablehlo.dynamic_broadcast_in_dim %10661, %from_elements_3733, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3736 = tensor.dim %10666, %c1 : tensor<1x?x4096xi64> + %10667 = arith.index_cast %dim_3736 : index to i64 + %from_elements_3737 = tensor.from_elements %c1_i64, %10667, %c4096_i64, %c1_i64 : tensor<4xi64> + %10668 = stablehlo.dynamic_reshape %10666, %from_elements_3737 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %10669 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_3733, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3738 = tensor.dim %10669, %c1 : tensor<1x?x4096xi64> + %10670 = arith.index_cast %dim_3738 : index to i64 + %from_elements_3739 = tensor.from_elements %c1_i64, %10670, %c4096_i64, %c1_i64 : tensor<4xi64> + %10671 = stablehlo.dynamic_reshape %10669, %from_elements_3739 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %10672 = stablehlo.concatenate %10665, %10668, %10671, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %10673 = "stablehlo.gather"(%10595, %10672) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %10674 = shape.shape_of %10673 : tensor<1x?x4096xf32> -> tensor<3xindex> + %10675 = shape.num_elements %10674 : tensor<3xindex> -> index + %10676 = stablehlo.compute_reshape_shape %10675, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %10677 = stablehlo.dynamic_reshape %10673, %10676 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %10678 = stablehlo.dot %10677, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %10679 = stablehlo.logistic %10678 : tensor + %10680 = shape.shape_of %10679 : tensor -> tensor<2xindex> + %10681 = shape.shape_of %10678 : tensor -> tensor<2xindex> + %10682 = shape.cstr_broadcastable %10680, %10681 : tensor<2xindex>, tensor<2xindex> + %10683 = shape.assuming %10682 -> (tensor) { + %19688 = shape.broadcast %10680, %10681 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %10679, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %10678, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %10684 = shape.shape_of %10683 : tensor -> tensor<2xindex> + %10685 = shape.cstr_broadcastable %10684, %10681 : tensor<2xindex>, tensor<2xindex> + %10686 = shape.assuming %10685 -> (tensor) { + %19688 = shape.broadcast %10684, %10681 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %10683, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %10678, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %10687 = stablehlo.dot %10686, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_3740 = tensor.dim %10659, %c0 : tensor + %10688 = arith.index_cast %dim_3740 : index to i64 + %from_elements_3741 = tensor.from_elements %10688, %c1_i64 : tensor<2xi64> + %10689 = stablehlo.dynamic_reshape %10659, %from_elements_3741 : (tensor, tensor<2xi64>) -> tensor + %dim_3742 = tensor.dim 
%10656, %c0 : tensor + %10690 = arith.index_cast %dim_3742 : index to i64 + %from_elements_3743 = tensor.from_elements %10690, %c1_i64 : tensor<2xi64> + %10691 = stablehlo.dynamic_reshape %10656, %from_elements_3743 : (tensor, tensor<2xi64>) -> tensor + %10692 = stablehlo.concatenate %10689, %10691, dim = 1 : (tensor, tensor) -> tensor + %10693 = "stablehlo.gather"(%10624, %10692) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %10694 = shape.shape_of %10687 : tensor -> tensor<2xindex> + %10695 = shape.shape_of %10693 : tensor -> tensor<2xindex> + %10696 = shape.cstr_broadcastable %10694, %10695 : tensor<2xindex>, tensor<2xindex> + %10697 = shape.assuming %10696 -> (tensor) { + %19688 = shape.broadcast %10694, %10695 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %10687, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %10693, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %10698 = shape.shape_of %10697 : tensor -> tensor<2xindex> + %10699 = stablehlo.dynamic_broadcast_in_dim %10697, %10698, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %10700 = stablehlo.dynamic_broadcast_in_dim %213, %10698, dims = [] : (tensor, tensor<2xindex>) -> tensor + %10701 = stablehlo.multiply %10699, %10700 : tensor + %dim_3744 = tensor.dim %10661, %c0 : tensor + %10702 = arith.index_cast %dim_3744 : index to i64 + %dim_3745 = tensor.dim %10697, %c0 : tensor + %10703 = arith.index_cast %dim_3745 : index to i64 + %10704 = arith.maxsi %10702, %10703 : i64 + %10705 = arith.index_cast %10704 : i64 to index + %from_elements_3746 = tensor.from_elements %10705, %c4096 : tensor<2xindex> + %10706 = stablehlo.dynamic_broadcast_in_dim %10661, %from_elements_3746, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_3747 = tensor.dim %10706, %c0 : tensor + %10707 = arith.index_cast %dim_3747 : index to i64 + %from_elements_3748 = tensor.from_elements %10707, %c4096_i64 : tensor<2xi64> + %10708 = stablehlo.real_dynamic_slice %10701, %c_22, %from_elements_3748, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_3749 = tensor.from_elements %10707, %c4096_i64, %c1_i64 : tensor<3xi64> + %10709 = stablehlo.dynamic_reshape %10706, %from_elements_3749 : (tensor, tensor<3xi64>) -> tensor + %10710 = stablehlo.dynamic_iota %from_elements_3749, dim = 1 : (tensor<3xi64>) -> tensor + %10711 = stablehlo.concatenate %10709, %10710, dim = 2 : (tensor, tensor) -> tensor + %10712 = "stablehlo.scatter"(%10649, %10711, %10708) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %10713 = stablehlo.slice %10584 [2:3, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %10714 = stablehlo.reshape %10713 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %10715 = stablehlo.custom_call @byteir.non_zero(%10714) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_3750 = tensor.dim %10715, %c0 : tensor + %10716 = arith.index_cast %dim_3750 : index to i64 + %from_elements_3751 = tensor.from_elements %10716, %c1_i64 : tensor<2xi64> + %10717 = 
stablehlo.real_dynamic_slice %10715, %c_22, %from_elements_3751, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3752 = tensor.dim %10717, %c0 : tensor + %10718 = arith.index_cast %dim_3752 : index to i64 + %from_elements_3753 = tensor.from_elements %10718 : tensor<1xi64> + %10719 = stablehlo.dynamic_reshape %10717, %from_elements_3753 : (tensor, tensor<1xi64>) -> tensor + %from_elements_3754 = tensor.from_elements %10716, %c2_i64 : tensor<2xi64> + %10720 = stablehlo.real_dynamic_slice %10715, %c_24, %from_elements_3754, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3755 = tensor.dim %10720, %c0 : tensor + %10721 = arith.index_cast %dim_3755 : index to i64 + %from_elements_3756 = tensor.from_elements %10721 : tensor<1xi64> + %10722 = stablehlo.dynamic_reshape %10720, %from_elements_3756 : (tensor, tensor<1xi64>) -> tensor + %dim_3757 = tensor.dim %10722, %c0 : tensor + %10723 = arith.index_cast %dim_3757 : index to i64 + %from_elements_3758 = tensor.from_elements %10723, %c1_i64 : tensor<2xi64> + %10724 = stablehlo.dynamic_reshape %10722, %from_elements_3758 : (tensor, tensor<2xi64>) -> tensor + %dim_3759 = tensor.dim %10724, %c0 : tensor + %10725 = arith.index_cast %dim_3759 : index to i64 + %from_elements_3760 = tensor.from_elements %c1_i64, %10725, %c4096_i64 : tensor<3xi64> + %10726 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_3760, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3761 = tensor.dim %10726, %c1 : tensor<1x?x4096xi64> + %10727 = arith.index_cast %dim_3761 : index to i64 + %from_elements_3762 = tensor.from_elements %c1_i64, %10727, %c4096_i64, %c1_i64 : tensor<4xi64> + %10728 = stablehlo.dynamic_reshape %10726, %from_elements_3762 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %10729 = stablehlo.dynamic_broadcast_in_dim %10724, %from_elements_3760, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3763 = tensor.dim %10729, %c1 : tensor<1x?x4096xi64> + %10730 = arith.index_cast %dim_3763 : index to i64 + %from_elements_3764 = tensor.from_elements %c1_i64, %10730, %c4096_i64, %c1_i64 : tensor<4xi64> + %10731 = stablehlo.dynamic_reshape %10729, %from_elements_3764 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %10732 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_3760, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3765 = tensor.dim %10732, %c1 : tensor<1x?x4096xi64> + %10733 = arith.index_cast %dim_3765 : index to i64 + %from_elements_3766 = tensor.from_elements %c1_i64, %10733, %c4096_i64, %c1_i64 : tensor<4xi64> + %10734 = stablehlo.dynamic_reshape %10732, %from_elements_3766 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %10735 = stablehlo.concatenate %10728, %10731, %10734, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %10736 = "stablehlo.gather"(%10595, %10735) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %10737 = shape.shape_of %10736 : tensor<1x?x4096xf32> -> tensor<3xindex> + %10738 = shape.num_elements %10737 : tensor<3xindex> -> index + %10739 = stablehlo.compute_reshape_shape %10738, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %10740 = stablehlo.dynamic_reshape %10736, %10739 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %10741 = 
stablehlo.dot %10740, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %10742 = stablehlo.logistic %10741 : tensor + %10743 = shape.shape_of %10742 : tensor -> tensor<2xindex> + %10744 = shape.shape_of %10741 : tensor -> tensor<2xindex> + %10745 = shape.cstr_broadcastable %10743, %10744 : tensor<2xindex>, tensor<2xindex> + %10746 = shape.assuming %10745 -> (tensor) { + %19688 = shape.broadcast %10743, %10744 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %10742, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %10741, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %10747 = shape.shape_of %10746 : tensor -> tensor<2xindex> + %10748 = shape.cstr_broadcastable %10747, %10744 : tensor<2xindex>, tensor<2xindex> + %10749 = shape.assuming %10748 -> (tensor) { + %19688 = shape.broadcast %10747, %10744 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %10746, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %10741, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %10750 = stablehlo.dot %10749, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_3767 = tensor.dim %10722, %c0 : tensor + %10751 = arith.index_cast %dim_3767 : index to i64 + %from_elements_3768 = tensor.from_elements %10751, %c1_i64 : tensor<2xi64> + %10752 = stablehlo.dynamic_reshape %10722, %from_elements_3768 : (tensor, tensor<2xi64>) -> tensor + %dim_3769 = tensor.dim %10719, %c0 : tensor + %10753 = arith.index_cast %dim_3769 : index to i64 + %from_elements_3770 = tensor.from_elements %10753, %c1_i64 : tensor<2xi64> + %10754 = stablehlo.dynamic_reshape %10719, %from_elements_3770 : (tensor, tensor<2xi64>) -> tensor + %10755 = stablehlo.concatenate %10752, %10754, dim = 1 : (tensor, tensor) -> tensor + %10756 = "stablehlo.gather"(%10624, %10755) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %10757 = shape.shape_of %10750 : tensor -> tensor<2xindex> + %10758 = shape.shape_of %10756 : tensor -> tensor<2xindex> + %10759 = shape.cstr_broadcastable %10757, %10758 : tensor<2xindex>, tensor<2xindex> + %10760 = shape.assuming %10759 -> (tensor) { + %19688 = shape.broadcast %10757, %10758 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %10750, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %10756, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %10761 = shape.shape_of %10760 : tensor -> tensor<2xindex> + %10762 = stablehlo.dynamic_broadcast_in_dim %10760, %10761, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %10763 = stablehlo.dynamic_broadcast_in_dim %213, %10761, dims = [] : (tensor, tensor<2xindex>) -> tensor + %10764 = stablehlo.multiply %10762, %10763 : tensor + %dim_3771 = tensor.dim %10724, %c0 : tensor + %10765 = arith.index_cast %dim_3771 : index to i64 + %dim_3772 = tensor.dim %10760, %c0 : tensor + %10766 = arith.index_cast %dim_3772 : index to i64 + %10767 = arith.maxsi %10765, 
%10766 : i64 + %10768 = arith.index_cast %10767 : i64 to index + %from_elements_3773 = tensor.from_elements %10768, %c4096 : tensor<2xindex> + %10769 = stablehlo.dynamic_broadcast_in_dim %10724, %from_elements_3773, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_3774 = tensor.dim %10769, %c0 : tensor + %10770 = arith.index_cast %dim_3774 : index to i64 + %from_elements_3775 = tensor.from_elements %10770, %c4096_i64 : tensor<2xi64> + %10771 = stablehlo.real_dynamic_slice %10764, %c_22, %from_elements_3775, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_3776 = tensor.from_elements %10770, %c4096_i64, %c1_i64 : tensor<3xi64> + %10772 = stablehlo.dynamic_reshape %10769, %from_elements_3776 : (tensor, tensor<3xi64>) -> tensor + %10773 = stablehlo.dynamic_iota %from_elements_3776, dim = 1 : (tensor<3xi64>) -> tensor + %10774 = stablehlo.concatenate %10772, %10773, dim = 2 : (tensor, tensor) -> tensor + %10775 = "stablehlo.scatter"(%10712, %10774, %10771) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %10776 = stablehlo.slice %10584 [3:4, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %10777 = stablehlo.reshape %10776 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %10778 = stablehlo.custom_call @byteir.non_zero(%10777) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_3777 = tensor.dim %10778, %c0 : tensor + %10779 = arith.index_cast %dim_3777 : index to i64 + %from_elements_3778 = tensor.from_elements %10779, %c1_i64 : tensor<2xi64> + %10780 = stablehlo.real_dynamic_slice %10778, %c_22, %from_elements_3778, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3779 = tensor.dim %10780, %c0 : tensor + %10781 = arith.index_cast %dim_3779 : index to i64 + %from_elements_3780 = tensor.from_elements %10781 : tensor<1xi64> + %10782 = stablehlo.dynamic_reshape %10780, %from_elements_3780 : (tensor, tensor<1xi64>) -> tensor + %from_elements_3781 = tensor.from_elements %10779, %c2_i64 : tensor<2xi64> + %10783 = stablehlo.real_dynamic_slice %10778, %c_24, %from_elements_3781, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3782 = tensor.dim %10783, %c0 : tensor + %10784 = arith.index_cast %dim_3782 : index to i64 + %from_elements_3783 = tensor.from_elements %10784 : tensor<1xi64> + %10785 = stablehlo.dynamic_reshape %10783, %from_elements_3783 : (tensor, tensor<1xi64>) -> tensor + %dim_3784 = tensor.dim %10785, %c0 : tensor + %10786 = arith.index_cast %dim_3784 : index to i64 + %from_elements_3785 = tensor.from_elements %10786, %c1_i64 : tensor<2xi64> + %10787 = stablehlo.dynamic_reshape %10785, %from_elements_3785 : (tensor, tensor<2xi64>) -> tensor + %dim_3786 = tensor.dim %10787, %c0 : tensor + %10788 = arith.index_cast %dim_3786 : index to i64 + %from_elements_3787 = tensor.from_elements %c1_i64, %10788, %c4096_i64 : tensor<3xi64> + %10789 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_3787, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3788 = tensor.dim %10789, %c1 : tensor<1x?x4096xi64> + %10790 = arith.index_cast %dim_3788 : index to i64 + %from_elements_3789 = tensor.from_elements %c1_i64, %10790, %c4096_i64, %c1_i64 : tensor<4xi64> + %10791 = stablehlo.dynamic_reshape %10789, 
%from_elements_3789 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %10792 = stablehlo.dynamic_broadcast_in_dim %10787, %from_elements_3787, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3790 = tensor.dim %10792, %c1 : tensor<1x?x4096xi64> + %10793 = arith.index_cast %dim_3790 : index to i64 + %from_elements_3791 = tensor.from_elements %c1_i64, %10793, %c4096_i64, %c1_i64 : tensor<4xi64> + %10794 = stablehlo.dynamic_reshape %10792, %from_elements_3791 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %10795 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_3787, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3792 = tensor.dim %10795, %c1 : tensor<1x?x4096xi64> + %10796 = arith.index_cast %dim_3792 : index to i64 + %from_elements_3793 = tensor.from_elements %c1_i64, %10796, %c4096_i64, %c1_i64 : tensor<4xi64> + %10797 = stablehlo.dynamic_reshape %10795, %from_elements_3793 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %10798 = stablehlo.concatenate %10791, %10794, %10797, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %10799 = "stablehlo.gather"(%10595, %10798) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %10800 = shape.shape_of %10799 : tensor<1x?x4096xf32> -> tensor<3xindex> + %10801 = shape.num_elements %10800 : tensor<3xindex> -> index + %10802 = stablehlo.compute_reshape_shape %10801, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %10803 = stablehlo.dynamic_reshape %10799, %10802 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %10804 = stablehlo.dot %10803, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %10805 = stablehlo.logistic %10804 : tensor + %10806 = shape.shape_of %10805 : tensor -> tensor<2xindex> + %10807 = shape.shape_of %10804 : tensor -> tensor<2xindex> + %10808 = shape.cstr_broadcastable %10806, %10807 : tensor<2xindex>, tensor<2xindex> + %10809 = shape.assuming %10808 -> (tensor) { + %19688 = shape.broadcast %10806, %10807 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %10805, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %10804, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %10810 = shape.shape_of %10809 : tensor -> tensor<2xindex> + %10811 = shape.cstr_broadcastable %10810, %10807 : tensor<2xindex>, tensor<2xindex> + %10812 = shape.assuming %10811 -> (tensor) { + %19688 = shape.broadcast %10810, %10807 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %10809, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %10804, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %10813 = stablehlo.dot %10812, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_3794 = tensor.dim %10785, %c0 : tensor + %10814 = arith.index_cast %dim_3794 : index to i64 + %from_elements_3795 = tensor.from_elements %10814, %c1_i64 : tensor<2xi64> + %10815 = stablehlo.dynamic_reshape %10785, %from_elements_3795 : (tensor, tensor<2xi64>) -> tensor + 
%dim_3796 = tensor.dim %10782, %c0 : tensor + %10816 = arith.index_cast %dim_3796 : index to i64 + %from_elements_3797 = tensor.from_elements %10816, %c1_i64 : tensor<2xi64> + %10817 = stablehlo.dynamic_reshape %10782, %from_elements_3797 : (tensor, tensor<2xi64>) -> tensor + %10818 = stablehlo.concatenate %10815, %10817, dim = 1 : (tensor, tensor) -> tensor + %10819 = "stablehlo.gather"(%10624, %10818) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %10820 = shape.shape_of %10813 : tensor -> tensor<2xindex> + %10821 = shape.shape_of %10819 : tensor -> tensor<2xindex> + %10822 = shape.cstr_broadcastable %10820, %10821 : tensor<2xindex>, tensor<2xindex> + %10823 = shape.assuming %10822 -> (tensor) { + %19688 = shape.broadcast %10820, %10821 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %10813, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %10819, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %10824 = shape.shape_of %10823 : tensor -> tensor<2xindex> + %10825 = stablehlo.dynamic_broadcast_in_dim %10823, %10824, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %10826 = stablehlo.dynamic_broadcast_in_dim %213, %10824, dims = [] : (tensor, tensor<2xindex>) -> tensor + %10827 = stablehlo.multiply %10825, %10826 : tensor + %dim_3798 = tensor.dim %10787, %c0 : tensor + %10828 = arith.index_cast %dim_3798 : index to i64 + %dim_3799 = tensor.dim %10823, %c0 : tensor + %10829 = arith.index_cast %dim_3799 : index to i64 + %10830 = arith.maxsi %10828, %10829 : i64 + %10831 = arith.index_cast %10830 : i64 to index + %from_elements_3800 = tensor.from_elements %10831, %c4096 : tensor<2xindex> + %10832 = stablehlo.dynamic_broadcast_in_dim %10787, %from_elements_3800, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_3801 = tensor.dim %10832, %c0 : tensor + %10833 = arith.index_cast %dim_3801 : index to i64 + %from_elements_3802 = tensor.from_elements %10833, %c4096_i64 : tensor<2xi64> + %10834 = stablehlo.real_dynamic_slice %10827, %c_22, %from_elements_3802, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_3803 = tensor.from_elements %10833, %c4096_i64, %c1_i64 : tensor<3xi64> + %10835 = stablehlo.dynamic_reshape %10832, %from_elements_3803 : (tensor, tensor<3xi64>) -> tensor + %10836 = stablehlo.dynamic_iota %from_elements_3803, dim = 1 : (tensor<3xi64>) -> tensor + %10837 = stablehlo.concatenate %10835, %10836, dim = 2 : (tensor, tensor) -> tensor + %10838 = "stablehlo.scatter"(%10775, %10837, %10834) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %10839 = stablehlo.slice %10584 [4:5, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %10840 = stablehlo.reshape %10839 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %10841 = stablehlo.custom_call @byteir.non_zero(%10840) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_3804 = tensor.dim %10841, %c0 : tensor + %10842 = arith.index_cast %dim_3804 : index to i64 + %from_elements_3805 = tensor.from_elements %10842, %c1_i64 : tensor<2xi64> + 
%10843 = stablehlo.real_dynamic_slice %10841, %c_22, %from_elements_3805, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3806 = tensor.dim %10843, %c0 : tensor + %10844 = arith.index_cast %dim_3806 : index to i64 + %from_elements_3807 = tensor.from_elements %10844 : tensor<1xi64> + %10845 = stablehlo.dynamic_reshape %10843, %from_elements_3807 : (tensor, tensor<1xi64>) -> tensor + %from_elements_3808 = tensor.from_elements %10842, %c2_i64 : tensor<2xi64> + %10846 = stablehlo.real_dynamic_slice %10841, %c_24, %from_elements_3808, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3809 = tensor.dim %10846, %c0 : tensor + %10847 = arith.index_cast %dim_3809 : index to i64 + %from_elements_3810 = tensor.from_elements %10847 : tensor<1xi64> + %10848 = stablehlo.dynamic_reshape %10846, %from_elements_3810 : (tensor, tensor<1xi64>) -> tensor + %dim_3811 = tensor.dim %10848, %c0 : tensor + %10849 = arith.index_cast %dim_3811 : index to i64 + %from_elements_3812 = tensor.from_elements %10849, %c1_i64 : tensor<2xi64> + %10850 = stablehlo.dynamic_reshape %10848, %from_elements_3812 : (tensor, tensor<2xi64>) -> tensor + %dim_3813 = tensor.dim %10850, %c0 : tensor + %10851 = arith.index_cast %dim_3813 : index to i64 + %from_elements_3814 = tensor.from_elements %c1_i64, %10851, %c4096_i64 : tensor<3xi64> + %10852 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_3814, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3815 = tensor.dim %10852, %c1 : tensor<1x?x4096xi64> + %10853 = arith.index_cast %dim_3815 : index to i64 + %from_elements_3816 = tensor.from_elements %c1_i64, %10853, %c4096_i64, %c1_i64 : tensor<4xi64> + %10854 = stablehlo.dynamic_reshape %10852, %from_elements_3816 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %10855 = stablehlo.dynamic_broadcast_in_dim %10850, %from_elements_3814, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3817 = tensor.dim %10855, %c1 : tensor<1x?x4096xi64> + %10856 = arith.index_cast %dim_3817 : index to i64 + %from_elements_3818 = tensor.from_elements %c1_i64, %10856, %c4096_i64, %c1_i64 : tensor<4xi64> + %10857 = stablehlo.dynamic_reshape %10855, %from_elements_3818 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %10858 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_3814, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3819 = tensor.dim %10858, %c1 : tensor<1x?x4096xi64> + %10859 = arith.index_cast %dim_3819 : index to i64 + %from_elements_3820 = tensor.from_elements %c1_i64, %10859, %c4096_i64, %c1_i64 : tensor<4xi64> + %10860 = stablehlo.dynamic_reshape %10858, %from_elements_3820 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %10861 = stablehlo.concatenate %10854, %10857, %10860, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %10862 = "stablehlo.gather"(%10595, %10861) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %10863 = shape.shape_of %10862 : tensor<1x?x4096xf32> -> tensor<3xindex> + %10864 = shape.num_elements %10863 : tensor<3xindex> -> index + %10865 = stablehlo.compute_reshape_shape %10864, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %10866 = stablehlo.dynamic_reshape %10862, %10865 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + 
%10867 = stablehlo.dot %10866, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %10868 = stablehlo.logistic %10867 : tensor + %10869 = shape.shape_of %10868 : tensor -> tensor<2xindex> + %10870 = shape.shape_of %10867 : tensor -> tensor<2xindex> + %10871 = shape.cstr_broadcastable %10869, %10870 : tensor<2xindex>, tensor<2xindex> + %10872 = shape.assuming %10871 -> (tensor) { + %19688 = shape.broadcast %10869, %10870 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %10868, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %10867, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %10873 = shape.shape_of %10872 : tensor -> tensor<2xindex> + %10874 = shape.cstr_broadcastable %10873, %10870 : tensor<2xindex>, tensor<2xindex> + %10875 = shape.assuming %10874 -> (tensor) { + %19688 = shape.broadcast %10873, %10870 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %10872, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %10867, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %10876 = stablehlo.dot %10875, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_3821 = tensor.dim %10848, %c0 : tensor + %10877 = arith.index_cast %dim_3821 : index to i64 + %from_elements_3822 = tensor.from_elements %10877, %c1_i64 : tensor<2xi64> + %10878 = stablehlo.dynamic_reshape %10848, %from_elements_3822 : (tensor, tensor<2xi64>) -> tensor + %dim_3823 = tensor.dim %10845, %c0 : tensor + %10879 = arith.index_cast %dim_3823 : index to i64 + %from_elements_3824 = tensor.from_elements %10879, %c1_i64 : tensor<2xi64> + %10880 = stablehlo.dynamic_reshape %10845, %from_elements_3824 : (tensor, tensor<2xi64>) -> tensor + %10881 = stablehlo.concatenate %10878, %10880, dim = 1 : (tensor, tensor) -> tensor + %10882 = "stablehlo.gather"(%10624, %10881) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %10883 = shape.shape_of %10876 : tensor -> tensor<2xindex> + %10884 = shape.shape_of %10882 : tensor -> tensor<2xindex> + %10885 = shape.cstr_broadcastable %10883, %10884 : tensor<2xindex>, tensor<2xindex> + %10886 = shape.assuming %10885 -> (tensor) { + %19688 = shape.broadcast %10883, %10884 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %10876, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %10882, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %10887 = shape.shape_of %10886 : tensor -> tensor<2xindex> + %10888 = stablehlo.dynamic_broadcast_in_dim %10886, %10887, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %10889 = stablehlo.dynamic_broadcast_in_dim %213, %10887, dims = [] : (tensor, tensor<2xindex>) -> tensor + %10890 = stablehlo.multiply %10888, %10889 : tensor + %dim_3825 = tensor.dim %10850, %c0 : tensor + %10891 = arith.index_cast %dim_3825 : index to i64 + %dim_3826 = tensor.dim %10886, %c0 : tensor + %10892 = arith.index_cast %dim_3826 : index to i64 + %10893 = arith.maxsi 
%10891, %10892 : i64 + %10894 = arith.index_cast %10893 : i64 to index + %from_elements_3827 = tensor.from_elements %10894, %c4096 : tensor<2xindex> + %10895 = stablehlo.dynamic_broadcast_in_dim %10850, %from_elements_3827, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_3828 = tensor.dim %10895, %c0 : tensor + %10896 = arith.index_cast %dim_3828 : index to i64 + %from_elements_3829 = tensor.from_elements %10896, %c4096_i64 : tensor<2xi64> + %10897 = stablehlo.real_dynamic_slice %10890, %c_22, %from_elements_3829, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_3830 = tensor.from_elements %10896, %c4096_i64, %c1_i64 : tensor<3xi64> + %10898 = stablehlo.dynamic_reshape %10895, %from_elements_3830 : (tensor, tensor<3xi64>) -> tensor + %10899 = stablehlo.dynamic_iota %from_elements_3830, dim = 1 : (tensor<3xi64>) -> tensor + %10900 = stablehlo.concatenate %10898, %10899, dim = 2 : (tensor, tensor) -> tensor + %10901 = "stablehlo.scatter"(%10838, %10900, %10897) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %10902 = stablehlo.slice %10584 [5:6, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %10903 = stablehlo.reshape %10902 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %10904 = stablehlo.custom_call @byteir.non_zero(%10903) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_3831 = tensor.dim %10904, %c0 : tensor + %10905 = arith.index_cast %dim_3831 : index to i64 + %from_elements_3832 = tensor.from_elements %10905, %c1_i64 : tensor<2xi64> + %10906 = stablehlo.real_dynamic_slice %10904, %c_22, %from_elements_3832, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3833 = tensor.dim %10906, %c0 : tensor + %10907 = arith.index_cast %dim_3833 : index to i64 + %from_elements_3834 = tensor.from_elements %10907 : tensor<1xi64> + %10908 = stablehlo.dynamic_reshape %10906, %from_elements_3834 : (tensor, tensor<1xi64>) -> tensor + %from_elements_3835 = tensor.from_elements %10905, %c2_i64 : tensor<2xi64> + %10909 = stablehlo.real_dynamic_slice %10904, %c_24, %from_elements_3835, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3836 = tensor.dim %10909, %c0 : tensor + %10910 = arith.index_cast %dim_3836 : index to i64 + %from_elements_3837 = tensor.from_elements %10910 : tensor<1xi64> + %10911 = stablehlo.dynamic_reshape %10909, %from_elements_3837 : (tensor, tensor<1xi64>) -> tensor + %dim_3838 = tensor.dim %10911, %c0 : tensor + %10912 = arith.index_cast %dim_3838 : index to i64 + %from_elements_3839 = tensor.from_elements %10912, %c1_i64 : tensor<2xi64> + %10913 = stablehlo.dynamic_reshape %10911, %from_elements_3839 : (tensor, tensor<2xi64>) -> tensor + %dim_3840 = tensor.dim %10913, %c0 : tensor + %10914 = arith.index_cast %dim_3840 : index to i64 + %from_elements_3841 = tensor.from_elements %c1_i64, %10914, %c4096_i64 : tensor<3xi64> + %10915 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_3841, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3842 = tensor.dim %10915, %c1 : tensor<1x?x4096xi64> + %10916 = arith.index_cast %dim_3842 : index to i64 + %from_elements_3843 = tensor.from_elements %c1_i64, %10916, %c4096_i64, %c1_i64 : tensor<4xi64> + %10917 = stablehlo.dynamic_reshape 
%10915, %from_elements_3843 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %10918 = stablehlo.dynamic_broadcast_in_dim %10913, %from_elements_3841, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3844 = tensor.dim %10918, %c1 : tensor<1x?x4096xi64> + %10919 = arith.index_cast %dim_3844 : index to i64 + %from_elements_3845 = tensor.from_elements %c1_i64, %10919, %c4096_i64, %c1_i64 : tensor<4xi64> + %10920 = stablehlo.dynamic_reshape %10918, %from_elements_3845 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %10921 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_3841, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3846 = tensor.dim %10921, %c1 : tensor<1x?x4096xi64> + %10922 = arith.index_cast %dim_3846 : index to i64 + %from_elements_3847 = tensor.from_elements %c1_i64, %10922, %c4096_i64, %c1_i64 : tensor<4xi64> + %10923 = stablehlo.dynamic_reshape %10921, %from_elements_3847 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %10924 = stablehlo.concatenate %10917, %10920, %10923, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %10925 = "stablehlo.gather"(%10595, %10924) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %10926 = shape.shape_of %10925 : tensor<1x?x4096xf32> -> tensor<3xindex> + %10927 = shape.num_elements %10926 : tensor<3xindex> -> index + %10928 = stablehlo.compute_reshape_shape %10927, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %10929 = stablehlo.dynamic_reshape %10925, %10928 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %10930 = stablehlo.dot %10929, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %10931 = stablehlo.logistic %10930 : tensor + %10932 = shape.shape_of %10931 : tensor -> tensor<2xindex> + %10933 = shape.shape_of %10930 : tensor -> tensor<2xindex> + %10934 = shape.cstr_broadcastable %10932, %10933 : tensor<2xindex>, tensor<2xindex> + %10935 = shape.assuming %10934 -> (tensor) { + %19688 = shape.broadcast %10932, %10933 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %10931, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %10930, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %10936 = shape.shape_of %10935 : tensor -> tensor<2xindex> + %10937 = shape.cstr_broadcastable %10936, %10933 : tensor<2xindex>, tensor<2xindex> + %10938 = shape.assuming %10937 -> (tensor) { + %19688 = shape.broadcast %10936, %10933 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %10935, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %10930, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %10939 = stablehlo.dot %10938, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_3848 = tensor.dim %10911, %c0 : tensor + %10940 = arith.index_cast %dim_3848 : index to i64 + %from_elements_3849 = tensor.from_elements %10940, %c1_i64 : tensor<2xi64> + %10941 = stablehlo.dynamic_reshape %10911, %from_elements_3849 : (tensor, tensor<2xi64>) -> 
tensor + %dim_3850 = tensor.dim %10908, %c0 : tensor + %10942 = arith.index_cast %dim_3850 : index to i64 + %from_elements_3851 = tensor.from_elements %10942, %c1_i64 : tensor<2xi64> + %10943 = stablehlo.dynamic_reshape %10908, %from_elements_3851 : (tensor, tensor<2xi64>) -> tensor + %10944 = stablehlo.concatenate %10941, %10943, dim = 1 : (tensor, tensor) -> tensor + %10945 = "stablehlo.gather"(%10624, %10944) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %10946 = shape.shape_of %10939 : tensor -> tensor<2xindex> + %10947 = shape.shape_of %10945 : tensor -> tensor<2xindex> + %10948 = shape.cstr_broadcastable %10946, %10947 : tensor<2xindex>, tensor<2xindex> + %10949 = shape.assuming %10948 -> (tensor) { + %19688 = shape.broadcast %10946, %10947 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %10939, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %10945, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %10950 = shape.shape_of %10949 : tensor -> tensor<2xindex> + %10951 = stablehlo.dynamic_broadcast_in_dim %10949, %10950, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %10952 = stablehlo.dynamic_broadcast_in_dim %213, %10950, dims = [] : (tensor, tensor<2xindex>) -> tensor + %10953 = stablehlo.multiply %10951, %10952 : tensor + %dim_3852 = tensor.dim %10913, %c0 : tensor + %10954 = arith.index_cast %dim_3852 : index to i64 + %dim_3853 = tensor.dim %10949, %c0 : tensor + %10955 = arith.index_cast %dim_3853 : index to i64 + %10956 = arith.maxsi %10954, %10955 : i64 + %10957 = arith.index_cast %10956 : i64 to index + %from_elements_3854 = tensor.from_elements %10957, %c4096 : tensor<2xindex> + %10958 = stablehlo.dynamic_broadcast_in_dim %10913, %from_elements_3854, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_3855 = tensor.dim %10958, %c0 : tensor + %10959 = arith.index_cast %dim_3855 : index to i64 + %from_elements_3856 = tensor.from_elements %10959, %c4096_i64 : tensor<2xi64> + %10960 = stablehlo.real_dynamic_slice %10953, %c_22, %from_elements_3856, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_3857 = tensor.from_elements %10959, %c4096_i64, %c1_i64 : tensor<3xi64> + %10961 = stablehlo.dynamic_reshape %10958, %from_elements_3857 : (tensor, tensor<3xi64>) -> tensor + %10962 = stablehlo.dynamic_iota %from_elements_3857, dim = 1 : (tensor<3xi64>) -> tensor + %10963 = stablehlo.concatenate %10961, %10962, dim = 2 : (tensor, tensor) -> tensor + %10964 = "stablehlo.scatter"(%10901, %10963, %10960) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %10965 = stablehlo.slice %10584 [6:7, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %10966 = stablehlo.reshape %10965 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %10967 = stablehlo.custom_call @byteir.non_zero(%10966) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_3858 = tensor.dim %10967, %c0 : tensor + %10968 = arith.index_cast %dim_3858 : index to i64 + %from_elements_3859 = tensor.from_elements %10968, %c1_i64 : 
tensor<2xi64> + %10969 = stablehlo.real_dynamic_slice %10967, %c_22, %from_elements_3859, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3860 = tensor.dim %10969, %c0 : tensor + %10970 = arith.index_cast %dim_3860 : index to i64 + %from_elements_3861 = tensor.from_elements %10970 : tensor<1xi64> + %10971 = stablehlo.dynamic_reshape %10969, %from_elements_3861 : (tensor, tensor<1xi64>) -> tensor + %from_elements_3862 = tensor.from_elements %10968, %c2_i64 : tensor<2xi64> + %10972 = stablehlo.real_dynamic_slice %10967, %c_24, %from_elements_3862, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3863 = tensor.dim %10972, %c0 : tensor + %10973 = arith.index_cast %dim_3863 : index to i64 + %from_elements_3864 = tensor.from_elements %10973 : tensor<1xi64> + %10974 = stablehlo.dynamic_reshape %10972, %from_elements_3864 : (tensor, tensor<1xi64>) -> tensor + %dim_3865 = tensor.dim %10974, %c0 : tensor + %10975 = arith.index_cast %dim_3865 : index to i64 + %from_elements_3866 = tensor.from_elements %10975, %c1_i64 : tensor<2xi64> + %10976 = stablehlo.dynamic_reshape %10974, %from_elements_3866 : (tensor, tensor<2xi64>) -> tensor + %dim_3867 = tensor.dim %10976, %c0 : tensor + %10977 = arith.index_cast %dim_3867 : index to i64 + %from_elements_3868 = tensor.from_elements %c1_i64, %10977, %c4096_i64 : tensor<3xi64> + %10978 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_3868, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3869 = tensor.dim %10978, %c1 : tensor<1x?x4096xi64> + %10979 = arith.index_cast %dim_3869 : index to i64 + %from_elements_3870 = tensor.from_elements %c1_i64, %10979, %c4096_i64, %c1_i64 : tensor<4xi64> + %10980 = stablehlo.dynamic_reshape %10978, %from_elements_3870 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %10981 = stablehlo.dynamic_broadcast_in_dim %10976, %from_elements_3868, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3871 = tensor.dim %10981, %c1 : tensor<1x?x4096xi64> + %10982 = arith.index_cast %dim_3871 : index to i64 + %from_elements_3872 = tensor.from_elements %c1_i64, %10982, %c4096_i64, %c1_i64 : tensor<4xi64> + %10983 = stablehlo.dynamic_reshape %10981, %from_elements_3872 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %10984 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_3868, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3873 = tensor.dim %10984, %c1 : tensor<1x?x4096xi64> + %10985 = arith.index_cast %dim_3873 : index to i64 + %from_elements_3874 = tensor.from_elements %c1_i64, %10985, %c4096_i64, %c1_i64 : tensor<4xi64> + %10986 = stablehlo.dynamic_reshape %10984, %from_elements_3874 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %10987 = stablehlo.concatenate %10980, %10983, %10986, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %10988 = "stablehlo.gather"(%10595, %10987) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %10989 = shape.shape_of %10988 : tensor<1x?x4096xf32> -> tensor<3xindex> + %10990 = shape.num_elements %10989 : tensor<3xindex> -> index + %10991 = stablehlo.compute_reshape_shape %10990, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %10992 = stablehlo.dynamic_reshape %10988, %10991 : (tensor<1x?x4096xf32>, tensor<2xi64>) 
-> tensor + %10993 = stablehlo.dot %10992, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %10994 = stablehlo.logistic %10993 : tensor + %10995 = shape.shape_of %10994 : tensor -> tensor<2xindex> + %10996 = shape.shape_of %10993 : tensor -> tensor<2xindex> + %10997 = shape.cstr_broadcastable %10995, %10996 : tensor<2xindex>, tensor<2xindex> + %10998 = shape.assuming %10997 -> (tensor) { + %19688 = shape.broadcast %10995, %10996 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %10994, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %10993, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %10999 = shape.shape_of %10998 : tensor -> tensor<2xindex> + %11000 = shape.cstr_broadcastable %10999, %10996 : tensor<2xindex>, tensor<2xindex> + %11001 = shape.assuming %11000 -> (tensor) { + %19688 = shape.broadcast %10999, %10996 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %10998, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %10993, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %11002 = stablehlo.dot %11001, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_3875 = tensor.dim %10974, %c0 : tensor + %11003 = arith.index_cast %dim_3875 : index to i64 + %from_elements_3876 = tensor.from_elements %11003, %c1_i64 : tensor<2xi64> + %11004 = stablehlo.dynamic_reshape %10974, %from_elements_3876 : (tensor, tensor<2xi64>) -> tensor + %dim_3877 = tensor.dim %10971, %c0 : tensor + %11005 = arith.index_cast %dim_3877 : index to i64 + %from_elements_3878 = tensor.from_elements %11005, %c1_i64 : tensor<2xi64> + %11006 = stablehlo.dynamic_reshape %10971, %from_elements_3878 : (tensor, tensor<2xi64>) -> tensor + %11007 = stablehlo.concatenate %11004, %11006, dim = 1 : (tensor, tensor) -> tensor + %11008 = "stablehlo.gather"(%10624, %11007) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %11009 = shape.shape_of %11002 : tensor -> tensor<2xindex> + %11010 = shape.shape_of %11008 : tensor -> tensor<2xindex> + %11011 = shape.cstr_broadcastable %11009, %11010 : tensor<2xindex>, tensor<2xindex> + %11012 = shape.assuming %11011 -> (tensor) { + %19688 = shape.broadcast %11009, %11010 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %11002, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %11008, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %11013 = shape.shape_of %11012 : tensor -> tensor<2xindex> + %11014 = stablehlo.dynamic_broadcast_in_dim %11012, %11013, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %11015 = stablehlo.dynamic_broadcast_in_dim %213, %11013, dims = [] : (tensor, tensor<2xindex>) -> tensor + %11016 = stablehlo.multiply %11014, %11015 : tensor + %dim_3879 = tensor.dim %10976, %c0 : tensor + %11017 = arith.index_cast %dim_3879 : index to i64 + %dim_3880 = tensor.dim %11012, %c0 : tensor + %11018 = arith.index_cast %dim_3880 : index to i64 + %11019 = 
arith.maxsi %11017, %11018 : i64 + %11020 = arith.index_cast %11019 : i64 to index + %from_elements_3881 = tensor.from_elements %11020, %c4096 : tensor<2xindex> + %11021 = stablehlo.dynamic_broadcast_in_dim %10976, %from_elements_3881, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_3882 = tensor.dim %11021, %c0 : tensor + %11022 = arith.index_cast %dim_3882 : index to i64 + %from_elements_3883 = tensor.from_elements %11022, %c4096_i64 : tensor<2xi64> + %11023 = stablehlo.real_dynamic_slice %11016, %c_22, %from_elements_3883, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_3884 = tensor.from_elements %11022, %c4096_i64, %c1_i64 : tensor<3xi64> + %11024 = stablehlo.dynamic_reshape %11021, %from_elements_3884 : (tensor, tensor<3xi64>) -> tensor + %11025 = stablehlo.dynamic_iota %from_elements_3884, dim = 1 : (tensor<3xi64>) -> tensor + %11026 = stablehlo.concatenate %11024, %11025, dim = 2 : (tensor, tensor) -> tensor + %11027 = "stablehlo.scatter"(%10964, %11026, %11023) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %11028 = stablehlo.slice %10584 [7:8, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %11029 = stablehlo.reshape %11028 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %11030 = stablehlo.custom_call @byteir.non_zero(%11029) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_3885 = tensor.dim %11030, %c0 : tensor + %11031 = arith.index_cast %dim_3885 : index to i64 + %from_elements_3886 = tensor.from_elements %11031, %c1_i64 : tensor<2xi64> + %11032 = stablehlo.real_dynamic_slice %11030, %c_22, %from_elements_3886, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3887 = tensor.dim %11032, %c0 : tensor + %11033 = arith.index_cast %dim_3887 : index to i64 + %from_elements_3888 = tensor.from_elements %11033 : tensor<1xi64> + %11034 = stablehlo.dynamic_reshape %11032, %from_elements_3888 : (tensor, tensor<1xi64>) -> tensor + %from_elements_3889 = tensor.from_elements %11031, %c2_i64 : tensor<2xi64> + %11035 = stablehlo.real_dynamic_slice %11030, %c_24, %from_elements_3889, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3890 = tensor.dim %11035, %c0 : tensor + %11036 = arith.index_cast %dim_3890 : index to i64 + %from_elements_3891 = tensor.from_elements %11036 : tensor<1xi64> + %11037 = stablehlo.dynamic_reshape %11035, %from_elements_3891 : (tensor, tensor<1xi64>) -> tensor + %dim_3892 = tensor.dim %11037, %c0 : tensor + %11038 = arith.index_cast %dim_3892 : index to i64 + %from_elements_3893 = tensor.from_elements %11038, %c1_i64 : tensor<2xi64> + %11039 = stablehlo.dynamic_reshape %11037, %from_elements_3893 : (tensor, tensor<2xi64>) -> tensor + %dim_3894 = tensor.dim %11039, %c0 : tensor + %11040 = arith.index_cast %dim_3894 : index to i64 + %from_elements_3895 = tensor.from_elements %c1_i64, %11040, %c4096_i64 : tensor<3xi64> + %11041 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_3895, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3896 = tensor.dim %11041, %c1 : tensor<1x?x4096xi64> + %11042 = arith.index_cast %dim_3896 : index to i64 + %from_elements_3897 = tensor.from_elements %c1_i64, %11042, %c4096_i64, %c1_i64 : tensor<4xi64> + %11043 = 
stablehlo.dynamic_reshape %11041, %from_elements_3897 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %11044 = stablehlo.dynamic_broadcast_in_dim %11039, %from_elements_3895, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3898 = tensor.dim %11044, %c1 : tensor<1x?x4096xi64> + %11045 = arith.index_cast %dim_3898 : index to i64 + %from_elements_3899 = tensor.from_elements %c1_i64, %11045, %c4096_i64, %c1_i64 : tensor<4xi64> + %11046 = stablehlo.dynamic_reshape %11044, %from_elements_3899 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %11047 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_3895, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3900 = tensor.dim %11047, %c1 : tensor<1x?x4096xi64> + %11048 = arith.index_cast %dim_3900 : index to i64 + %from_elements_3901 = tensor.from_elements %c1_i64, %11048, %c4096_i64, %c1_i64 : tensor<4xi64> + %11049 = stablehlo.dynamic_reshape %11047, %from_elements_3901 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %11050 = stablehlo.concatenate %11043, %11046, %11049, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %11051 = "stablehlo.gather"(%10595, %11050) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %11052 = shape.shape_of %11051 : tensor<1x?x4096xf32> -> tensor<3xindex> + %11053 = shape.num_elements %11052 : tensor<3xindex> -> index + %11054 = stablehlo.compute_reshape_shape %11053, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %11055 = stablehlo.dynamic_reshape %11051, %11054 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %11056 = stablehlo.dot %11055, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %11057 = stablehlo.logistic %11056 : tensor + %11058 = shape.shape_of %11057 : tensor -> tensor<2xindex> + %11059 = shape.shape_of %11056 : tensor -> tensor<2xindex> + %11060 = shape.cstr_broadcastable %11058, %11059 : tensor<2xindex>, tensor<2xindex> + %11061 = shape.assuming %11060 -> (tensor) { + %19688 = shape.broadcast %11058, %11059 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %11057, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %11056, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %11062 = shape.shape_of %11061 : tensor -> tensor<2xindex> + %11063 = shape.cstr_broadcastable %11062, %11059 : tensor<2xindex>, tensor<2xindex> + %11064 = shape.assuming %11063 -> (tensor) { + %19688 = shape.broadcast %11062, %11059 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %11061, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %11056, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %11065 = stablehlo.dot %11064, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_3902 = tensor.dim %11037, %c0 : tensor + %11066 = arith.index_cast %dim_3902 : index to i64 + %from_elements_3903 = tensor.from_elements %11066, %c1_i64 : tensor<2xi64> + %11067 = stablehlo.dynamic_reshape %11037, %from_elements_3903 : 
(tensor, tensor<2xi64>) -> tensor + %dim_3904 = tensor.dim %11034, %c0 : tensor + %11068 = arith.index_cast %dim_3904 : index to i64 + %from_elements_3905 = tensor.from_elements %11068, %c1_i64 : tensor<2xi64> + %11069 = stablehlo.dynamic_reshape %11034, %from_elements_3905 : (tensor, tensor<2xi64>) -> tensor + %11070 = stablehlo.concatenate %11067, %11069, dim = 1 : (tensor, tensor) -> tensor + %11071 = "stablehlo.gather"(%10624, %11070) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %11072 = shape.shape_of %11065 : tensor -> tensor<2xindex> + %11073 = shape.shape_of %11071 : tensor -> tensor<2xindex> + %11074 = shape.cstr_broadcastable %11072, %11073 : tensor<2xindex>, tensor<2xindex> + %11075 = shape.assuming %11074 -> (tensor) { + %19688 = shape.broadcast %11072, %11073 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %11065, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %11071, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %11076 = shape.shape_of %11075 : tensor -> tensor<2xindex> + %11077 = stablehlo.dynamic_broadcast_in_dim %11075, %11076, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %11078 = stablehlo.dynamic_broadcast_in_dim %213, %11076, dims = [] : (tensor, tensor<2xindex>) -> tensor + %11079 = stablehlo.multiply %11077, %11078 : tensor + %dim_3906 = tensor.dim %11039, %c0 : tensor + %11080 = arith.index_cast %dim_3906 : index to i64 + %dim_3907 = tensor.dim %11075, %c0 : tensor + %11081 = arith.index_cast %dim_3907 : index to i64 + %11082 = arith.maxsi %11080, %11081 : i64 + %11083 = arith.index_cast %11082 : i64 to index + %from_elements_3908 = tensor.from_elements %11083, %c4096 : tensor<2xindex> + %11084 = stablehlo.dynamic_broadcast_in_dim %11039, %from_elements_3908, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_3909 = tensor.dim %11084, %c0 : tensor + %11085 = arith.index_cast %dim_3909 : index to i64 + %from_elements_3910 = tensor.from_elements %11085, %c4096_i64 : tensor<2xi64> + %11086 = stablehlo.real_dynamic_slice %11079, %c_22, %from_elements_3910, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_3911 = tensor.from_elements %11085, %c4096_i64, %c1_i64 : tensor<3xi64> + %11087 = stablehlo.dynamic_reshape %11084, %from_elements_3911 : (tensor, tensor<3xi64>) -> tensor + %11088 = stablehlo.dynamic_iota %from_elements_3911, dim = 1 : (tensor<3xi64>) -> tensor + %11089 = stablehlo.concatenate %11087, %11088, dim = 2 : (tensor, tensor) -> tensor + %11090 = "stablehlo.scatter"(%11027, %11089, %11086) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %11091 = stablehlo.reshape %11090 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %11092 = stablehlo.add %10557, %11091 : tensor<3x1x4096xf32> + %11093 = stablehlo.broadcast_in_dim %11092, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %11094 = stablehlo.power %11093, %15 : tensor<3x1x4096xf32> + %11095 = stablehlo.reduce(%11094 init: %cst_20) applies stablehlo.add across dimensions = [2] : 
(tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %11096 = stablehlo.reshape %11095 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %11097 = stablehlo.broadcast_in_dim %11096, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %11098 = stablehlo.divide %11097, %21 : tensor<3x1x1xf32> + %11099 = stablehlo.broadcast_in_dim %11098, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %11100 = stablehlo.add %11099, %25 : tensor<3x1x1xf32> + %11101 = stablehlo.rsqrt %11100 : tensor<3x1x1xf32> + %11102 = stablehlo.broadcast_in_dim %11101, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %11103 = stablehlo.multiply %11093, %11102 : tensor<3x1x4096xf32> + %11104 = stablehlo.broadcast_in_dim %11103, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %11105 = stablehlo.multiply %11104, %31 : tensor<3x1x4096xf32> + %11106 = stablehlo.reshape %11105 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %11107 = stablehlo.dot %11106, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %11108 = stablehlo.reshape %11107 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %11109 = stablehlo.dot %11106, %37 : (tensor<3x4096xf32>, tensor<4096x1024xf32>) -> tensor<3x1024xf32> + %11110 = stablehlo.reshape %11109 : (tensor<3x1024xf32>) -> tensor<3x1x1024xf32> + %11111 = stablehlo.reshape %11108 : (tensor<3x1x4096xf32>) -> tensor<3x1x32x128xf32> + %11112 = stablehlo.transpose %11111, dims = [0, 2, 1, 3] : (tensor<3x1x32x128xf32>) -> tensor<3x32x1x128xf32> + %11113 = stablehlo.reshape %11110 : (tensor<3x1x1024xf32>) -> tensor<3x1x8x128xf32> + %11114 = stablehlo.transpose %11113, dims = [0, 2, 1, 3] : (tensor<3x1x8x128xf32>) -> tensor<3x8x1x128xf32> + %11115 = stablehlo.slice %arg36 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %11116 = stablehlo.slice %arg37 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %11117 = "stablehlo.gather"(%11115, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %11118 = stablehlo.reshape %11117 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %11119 = "stablehlo.gather"(%11116, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %11120 = stablehlo.reshape %11119 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %11121 = stablehlo.broadcast_in_dim %11112, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %11122 = stablehlo.broadcast_in_dim %11118, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %11123 = stablehlo.multiply %11121, %11122 : tensor<3x32x1x128xf32> + %11124 = stablehlo.slice %11112 [0:3, 0:32, 0:1, 0:64] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %11125 = stablehlo.slice %11112 [0:3, 0:32, 0:1, 64:128] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %11126 = stablehlo.negate %11125 : tensor<3x32x1x64xf32> + %11127 = stablehlo.concatenate %11126, %11124, dim = 3 : (tensor<3x32x1x64xf32>, tensor<3x32x1x64xf32>) -> tensor<3x32x1x128xf32> + %11128 = stablehlo.broadcast_in_dim %11127, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %11129 = stablehlo.broadcast_in_dim %11120, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %11130 = stablehlo.multiply %11128, %11129 : tensor<3x32x1x128xf32> + %11131 = stablehlo.add %11123, %11130 : 
tensor<3x32x1x128xf32> + %11132 = stablehlo.broadcast_in_dim %11114, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %11133 = stablehlo.broadcast_in_dim %11118, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %11134 = stablehlo.multiply %11132, %11133 : tensor<3x8x1x128xf32> + %11135 = stablehlo.slice %11114 [0:3, 0:8, 0:1, 0:64] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %11136 = stablehlo.slice %11114 [0:3, 0:8, 0:1, 64:128] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %11137 = stablehlo.negate %11136 : tensor<3x8x1x64xf32> + %11138 = stablehlo.concatenate %11137, %11135, dim = 3 : (tensor<3x8x1x64xf32>, tensor<3x8x1x64xf32>) -> tensor<3x8x1x128xf32> + %11139 = stablehlo.broadcast_in_dim %11138, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %11140 = stablehlo.broadcast_in_dim %11120, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %11141 = stablehlo.multiply %11139, %11140 : tensor<3x8x1x128xf32> + %11142 = stablehlo.add %11134, %11141 : tensor<3x8x1x128xf32> + %11143 = stablehlo.concatenate %arg101, %11142, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %11144 = stablehlo.concatenate %arg102, %11114, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %11145 = stablehlo.reshape %11143 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %11146 = stablehlo.broadcast_in_dim %11145, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %11147 = stablehlo.reshape %11146 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %11148 = stablehlo.reshape %11144 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %11149 = stablehlo.broadcast_in_dim %11148, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %11150 = stablehlo.reshape %11149 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %11151 = stablehlo.transpose %11147, dims = [0, 1, 3, 2] : (tensor<3x32x8x128xf32>) -> tensor<3x32x128x8xf32> + %11152 = stablehlo.reshape %11131 : (tensor<3x32x1x128xf32>) -> tensor<96x1x128xf32> + %11153 = stablehlo.reshape %11151 : (tensor<3x32x128x8xf32>) -> tensor<96x128x8xf32> + %11154 = stablehlo.broadcast_in_dim %11153, dims = [0, 1, 2] : (tensor<96x128x8xf32>) -> tensor<96x128x8xf32> + %11155 = stablehlo.dot_general %11152, %11154, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x128xf32>, tensor<96x128x8xf32>) -> tensor<96x1x8xf32> + %11156 = stablehlo.reshape %11155 : (tensor<96x1x8xf32>) -> tensor<3x32x1x8xf32> + %11157 = stablehlo.broadcast_in_dim %11156, dims = [0, 1, 2, 3] : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %11158 = stablehlo.divide %11157, %89 : tensor<3x32x1x8xf32> + %11159 = stablehlo.custom_call @byteir.softmax(%11158) {byteir_attrs = {axis = 3 : i64}} : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %11160 = stablehlo.reshape %11159 : (tensor<3x32x1x8xf32>) -> tensor<96x1x8xf32> + %11161 = stablehlo.reshape %11150 : (tensor<3x32x8x128xf32>) -> tensor<96x8x128xf32> + %11162 = stablehlo.broadcast_in_dim %11161, dims = [0, 1, 2] : (tensor<96x8x128xf32>) -> tensor<96x8x128xf32> + %11163 = stablehlo.dot_general %11160, %11162, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x8xf32>, tensor<96x8x128xf32>) -> tensor<96x1x128xf32> + %11164 = stablehlo.reshape %11163 : (tensor<96x1x128xf32>) -> tensor<3x32x1x128xf32> + %11165 = stablehlo.transpose %11164, dims = [0, 2, 1, 3] : 
(tensor<3x32x1x128xf32>) -> tensor<3x1x32x128xf32> + %11166 = stablehlo.reshape %11165 : (tensor<3x1x32x128xf32>) -> tensor<3x1x4096xf32> + %11167 = stablehlo.reshape %11166 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %11168 = stablehlo.dot %11167, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %11169 = stablehlo.reshape %11168 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %11170 = stablehlo.add %11092, %11169 : tensor<3x1x4096xf32> + %11171 = stablehlo.broadcast_in_dim %11170, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %11172 = stablehlo.power %11171, %15 : tensor<3x1x4096xf32> + %11173 = stablehlo.reduce(%11172 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %11174 = stablehlo.reshape %11173 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %11175 = stablehlo.broadcast_in_dim %11174, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %11176 = stablehlo.divide %11175, %21 : tensor<3x1x1xf32> + %11177 = stablehlo.broadcast_in_dim %11176, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %11178 = stablehlo.add %11177, %25 : tensor<3x1x1xf32> + %11179 = stablehlo.rsqrt %11178 : tensor<3x1x1xf32> + %11180 = stablehlo.broadcast_in_dim %11179, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %11181 = stablehlo.multiply %11171, %11180 : tensor<3x1x4096xf32> + %11182 = stablehlo.broadcast_in_dim %11181, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %11183 = stablehlo.multiply %11182, %31 : tensor<3x1x4096xf32> + %11184 = stablehlo.reshape %11183 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %11185 = stablehlo.dot %11184, %117 : (tensor<3x4096xf32>, tensor<4096x8xf32>) -> tensor<3x8xf32> + %11186 = stablehlo.custom_call @byteir.softmax(%11185) {byteir_attrs = {axis = 1 : i64}} : (tensor<3x8xf32>) -> tensor<3x8xf32> + %11187:2 = stablehlo.custom_call @byteir.top_k(%11186) {byteir_attrs = {axis = [1], k = 2 : i64, sorted = true}} : (tensor<3x8xf32>) -> (tensor<3x2xf32>, tensor<3x2xi64>) + %11188 = stablehlo.reduce(%11187#0 init: %cst_20) applies stablehlo.add across dimensions = [1] : (tensor<3x2xf32>, tensor) -> tensor<3xf32> + %11189 = stablehlo.reshape %11188 : (tensor<3xf32>) -> tensor<3x1xf32> + %11190 = stablehlo.broadcast_in_dim %11187#0, dims = [0, 1] : (tensor<3x2xf32>) -> tensor<3x2xf32> + %11191 = stablehlo.broadcast_in_dim %11189, dims = [0, 1] : (tensor<3x1xf32>) -> tensor<3x2xf32> + %11192 = stablehlo.divide %11190, %11191 : tensor<3x2xf32> + %11193 = stablehlo.reshape %11187#1 : (tensor<3x2xi64>) -> tensor<3x2x1xi64> + %11194 = stablehlo.broadcast_in_dim %11193, dims = [0, 1, 2] : (tensor<3x2x1xi64>) -> tensor<3x2x8xi64> + %11195 = stablehlo.compare EQ, %11194, %137, SIGNED : (tensor<3x2x8xi64>, tensor<3x2x8xi64>) -> tensor<3x2x8xi1> + %11196 = stablehlo.convert %11195 : (tensor<3x2x8xi1>) -> tensor<3x2x8xi64> + %11197 = stablehlo.transpose %11196, dims = [2, 1, 0] : (tensor<3x2x8xi64>) -> tensor<8x2x3xi64> + %11198 = stablehlo.slice %11197 [0:1, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %11199 = stablehlo.reshape %11198 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %11200 = stablehlo.custom_call @byteir.non_zero(%11199) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_3912 = tensor.dim %11200, %c0 : tensor + %11201 = arith.index_cast %dim_3912 : index to i64 + %from_elements_3913 = tensor.from_elements %11201, %c1_i64 : tensor<2xi64> + %11202 = stablehlo.real_dynamic_slice %11200, 
%c_22, %from_elements_3913, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3914 = tensor.dim %11202, %c0 : tensor + %11203 = arith.index_cast %dim_3914 : index to i64 + %from_elements_3915 = tensor.from_elements %11203 : tensor<1xi64> + %11204 = stablehlo.dynamic_reshape %11202, %from_elements_3915 : (tensor, tensor<1xi64>) -> tensor + %from_elements_3916 = tensor.from_elements %11201, %c2_i64 : tensor<2xi64> + %11205 = stablehlo.real_dynamic_slice %11200, %c_24, %from_elements_3916, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3917 = tensor.dim %11205, %c0 : tensor + %11206 = arith.index_cast %dim_3917 : index to i64 + %from_elements_3918 = tensor.from_elements %11206 : tensor<1xi64> + %11207 = stablehlo.dynamic_reshape %11205, %from_elements_3918 : (tensor, tensor<1xi64>) -> tensor + %11208 = stablehlo.reshape %11184 : (tensor<3x4096xf32>) -> tensor<1x3x4096xf32> + %dim_3919 = tensor.dim %11207, %c0 : tensor + %11209 = arith.index_cast %dim_3919 : index to i64 + %from_elements_3920 = tensor.from_elements %11209, %c1_i64 : tensor<2xi64> + %11210 = stablehlo.dynamic_reshape %11207, %from_elements_3920 : (tensor, tensor<2xi64>) -> tensor + %dim_3921 = tensor.dim %11210, %c0 : tensor + %11211 = arith.index_cast %dim_3921 : index to i64 + %from_elements_3922 = tensor.from_elements %c1_i64, %11211, %c4096_i64 : tensor<3xi64> + %11212 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_3922, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3923 = tensor.dim %11212, %c1 : tensor<1x?x4096xi64> + %11213 = arith.index_cast %dim_3923 : index to i64 + %from_elements_3924 = tensor.from_elements %c1_i64, %11213, %c4096_i64, %c1_i64 : tensor<4xi64> + %11214 = stablehlo.dynamic_reshape %11212, %from_elements_3924 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %11215 = stablehlo.dynamic_broadcast_in_dim %11210, %from_elements_3922, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3925 = tensor.dim %11215, %c1 : tensor<1x?x4096xi64> + %11216 = arith.index_cast %dim_3925 : index to i64 + %from_elements_3926 = tensor.from_elements %c1_i64, %11216, %c4096_i64, %c1_i64 : tensor<4xi64> + %11217 = stablehlo.dynamic_reshape %11215, %from_elements_3926 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %11218 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_3922, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3927 = tensor.dim %11218, %c1 : tensor<1x?x4096xi64> + %11219 = arith.index_cast %dim_3927 : index to i64 + %from_elements_3928 = tensor.from_elements %c1_i64, %11219, %c4096_i64, %c1_i64 : tensor<4xi64> + %11220 = stablehlo.dynamic_reshape %11218, %from_elements_3928 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %11221 = stablehlo.concatenate %11214, %11217, %11220, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %11222 = "stablehlo.gather"(%11208, %11221) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %11223 = shape.shape_of %11222 : tensor<1x?x4096xf32> -> tensor<3xindex> + %11224 = shape.num_elements %11223 : tensor<3xindex> -> index + %11225 = stablehlo.compute_reshape_shape %11224, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %11226 = stablehlo.dynamic_reshape %11222, %11225 : 
(tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %11227 = stablehlo.dot %11226, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %11228 = stablehlo.logistic %11227 : tensor + %11229 = shape.shape_of %11228 : tensor -> tensor<2xindex> + %11230 = shape.shape_of %11227 : tensor -> tensor<2xindex> + %11231 = shape.cstr_broadcastable %11229, %11230 : tensor<2xindex>, tensor<2xindex> + %11232 = shape.assuming %11231 -> (tensor) { + %19688 = shape.broadcast %11229, %11230 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %11228, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %11227, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %11233 = shape.shape_of %11232 : tensor -> tensor<2xindex> + %11234 = shape.cstr_broadcastable %11233, %11230 : tensor<2xindex>, tensor<2xindex> + %11235 = shape.assuming %11234 -> (tensor) { + %19688 = shape.broadcast %11233, %11230 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %11232, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %11227, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %11236 = stablehlo.dot %11235, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %11237 = stablehlo.reshape %11192 : (tensor<3x2xf32>) -> tensor<3x2x1xf32> + %dim_3929 = tensor.dim %11207, %c0 : tensor + %11238 = arith.index_cast %dim_3929 : index to i64 + %from_elements_3930 = tensor.from_elements %11238, %c1_i64 : tensor<2xi64> + %11239 = stablehlo.dynamic_reshape %11207, %from_elements_3930 : (tensor, tensor<2xi64>) -> tensor + %dim_3931 = tensor.dim %11204, %c0 : tensor + %11240 = arith.index_cast %dim_3931 : index to i64 + %from_elements_3932 = tensor.from_elements %11240, %c1_i64 : tensor<2xi64> + %11241 = stablehlo.dynamic_reshape %11204, %from_elements_3932 : (tensor, tensor<2xi64>) -> tensor + %11242 = stablehlo.concatenate %11239, %11241, dim = 1 : (tensor, tensor) -> tensor + %11243 = "stablehlo.gather"(%11237, %11242) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %11244 = shape.shape_of %11236 : tensor -> tensor<2xindex> + %11245 = shape.shape_of %11243 : tensor -> tensor<2xindex> + %11246 = shape.cstr_broadcastable %11244, %11245 : tensor<2xindex>, tensor<2xindex> + %11247 = shape.assuming %11246 -> (tensor) { + %19688 = shape.broadcast %11244, %11245 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %11236, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %11243, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %11248 = shape.shape_of %11247 : tensor -> tensor<2xindex> + %11249 = stablehlo.dynamic_broadcast_in_dim %11247, %11248, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %11250 = stablehlo.dynamic_broadcast_in_dim %213, %11248, dims = [] : (tensor, tensor<2xindex>) -> tensor + %11251 = stablehlo.multiply %11249, %11250 : tensor + %dim_3933 = tensor.dim %11210, %c0 : tensor + %11252 = arith.index_cast %dim_3933 : index to 
i64 + %dim_3934 = tensor.dim %11247, %c0 : tensor + %11253 = arith.index_cast %dim_3934 : index to i64 + %11254 = arith.maxsi %11252, %11253 : i64 + %11255 = arith.index_cast %11254 : i64 to index + %from_elements_3935 = tensor.from_elements %11255, %c4096 : tensor<2xindex> + %11256 = stablehlo.dynamic_broadcast_in_dim %11210, %from_elements_3935, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_3936 = tensor.dim %11256, %c0 : tensor + %11257 = arith.index_cast %dim_3936 : index to i64 + %from_elements_3937 = tensor.from_elements %11257, %c4096_i64 : tensor<2xi64> + %11258 = stablehlo.real_dynamic_slice %11251, %c_22, %from_elements_3937, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_3938 = tensor.from_elements %11257, %c4096_i64, %c1_i64 : tensor<3xi64> + %11259 = stablehlo.dynamic_reshape %11256, %from_elements_3938 : (tensor, tensor<3xi64>) -> tensor + %11260 = stablehlo.dynamic_iota %from_elements_3938, dim = 1 : (tensor<3xi64>) -> tensor + %11261 = stablehlo.concatenate %11259, %11260, dim = 2 : (tensor, tensor) -> tensor + %11262 = "stablehlo.scatter"(%cst_2, %11261, %11258) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %11263 = stablehlo.slice %11197 [1:2, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %11264 = stablehlo.reshape %11263 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %11265 = stablehlo.custom_call @byteir.non_zero(%11264) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_3939 = tensor.dim %11265, %c0 : tensor + %11266 = arith.index_cast %dim_3939 : index to i64 + %from_elements_3940 = tensor.from_elements %11266, %c1_i64 : tensor<2xi64> + %11267 = stablehlo.real_dynamic_slice %11265, %c_22, %from_elements_3940, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3941 = tensor.dim %11267, %c0 : tensor + %11268 = arith.index_cast %dim_3941 : index to i64 + %from_elements_3942 = tensor.from_elements %11268 : tensor<1xi64> + %11269 = stablehlo.dynamic_reshape %11267, %from_elements_3942 : (tensor, tensor<1xi64>) -> tensor + %from_elements_3943 = tensor.from_elements %11266, %c2_i64 : tensor<2xi64> + %11270 = stablehlo.real_dynamic_slice %11265, %c_24, %from_elements_3943, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3944 = tensor.dim %11270, %c0 : tensor + %11271 = arith.index_cast %dim_3944 : index to i64 + %from_elements_3945 = tensor.from_elements %11271 : tensor<1xi64> + %11272 = stablehlo.dynamic_reshape %11270, %from_elements_3945 : (tensor, tensor<1xi64>) -> tensor + %dim_3946 = tensor.dim %11272, %c0 : tensor + %11273 = arith.index_cast %dim_3946 : index to i64 + %from_elements_3947 = tensor.from_elements %11273, %c1_i64 : tensor<2xi64> + %11274 = stablehlo.dynamic_reshape %11272, %from_elements_3947 : (tensor, tensor<2xi64>) -> tensor + %dim_3948 = tensor.dim %11274, %c0 : tensor + %11275 = arith.index_cast %dim_3948 : index to i64 + %from_elements_3949 = tensor.from_elements %c1_i64, %11275, %c4096_i64 : tensor<3xi64> + %11276 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_3949, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3950 = tensor.dim %11276, %c1 : tensor<1x?x4096xi64> + %11277 = arith.index_cast %dim_3950 : index to i64 + 
%from_elements_3951 = tensor.from_elements %c1_i64, %11277, %c4096_i64, %c1_i64 : tensor<4xi64> + %11278 = stablehlo.dynamic_reshape %11276, %from_elements_3951 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %11279 = stablehlo.dynamic_broadcast_in_dim %11274, %from_elements_3949, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3952 = tensor.dim %11279, %c1 : tensor<1x?x4096xi64> + %11280 = arith.index_cast %dim_3952 : index to i64 + %from_elements_3953 = tensor.from_elements %c1_i64, %11280, %c4096_i64, %c1_i64 : tensor<4xi64> + %11281 = stablehlo.dynamic_reshape %11279, %from_elements_3953 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %11282 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_3949, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3954 = tensor.dim %11282, %c1 : tensor<1x?x4096xi64> + %11283 = arith.index_cast %dim_3954 : index to i64 + %from_elements_3955 = tensor.from_elements %c1_i64, %11283, %c4096_i64, %c1_i64 : tensor<4xi64> + %11284 = stablehlo.dynamic_reshape %11282, %from_elements_3955 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %11285 = stablehlo.concatenate %11278, %11281, %11284, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %11286 = "stablehlo.gather"(%11208, %11285) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %11287 = shape.shape_of %11286 : tensor<1x?x4096xf32> -> tensor<3xindex> + %11288 = shape.num_elements %11287 : tensor<3xindex> -> index + %11289 = stablehlo.compute_reshape_shape %11288, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %11290 = stablehlo.dynamic_reshape %11286, %11289 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %11291 = stablehlo.dot %11290, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %11292 = stablehlo.logistic %11291 : tensor + %11293 = shape.shape_of %11292 : tensor -> tensor<2xindex> + %11294 = shape.shape_of %11291 : tensor -> tensor<2xindex> + %11295 = shape.cstr_broadcastable %11293, %11294 : tensor<2xindex>, tensor<2xindex> + %11296 = shape.assuming %11295 -> (tensor) { + %19688 = shape.broadcast %11293, %11294 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %11292, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %11291, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %11297 = shape.shape_of %11296 : tensor -> tensor<2xindex> + %11298 = shape.cstr_broadcastable %11297, %11294 : tensor<2xindex>, tensor<2xindex> + %11299 = shape.assuming %11298 -> (tensor) { + %19688 = shape.broadcast %11297, %11294 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %11296, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %11291, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %11300 = stablehlo.dot %11299, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_3956 = tensor.dim %11272, %c0 : tensor + %11301 = arith.index_cast %dim_3956 : index to i64 + %from_elements_3957 = 
tensor.from_elements %11301, %c1_i64 : tensor<2xi64> + %11302 = stablehlo.dynamic_reshape %11272, %from_elements_3957 : (tensor, tensor<2xi64>) -> tensor + %dim_3958 = tensor.dim %11269, %c0 : tensor + %11303 = arith.index_cast %dim_3958 : index to i64 + %from_elements_3959 = tensor.from_elements %11303, %c1_i64 : tensor<2xi64> + %11304 = stablehlo.dynamic_reshape %11269, %from_elements_3959 : (tensor, tensor<2xi64>) -> tensor + %11305 = stablehlo.concatenate %11302, %11304, dim = 1 : (tensor, tensor) -> tensor + %11306 = "stablehlo.gather"(%11237, %11305) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %11307 = shape.shape_of %11300 : tensor -> tensor<2xindex> + %11308 = shape.shape_of %11306 : tensor -> tensor<2xindex> + %11309 = shape.cstr_broadcastable %11307, %11308 : tensor<2xindex>, tensor<2xindex> + %11310 = shape.assuming %11309 -> (tensor) { + %19688 = shape.broadcast %11307, %11308 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %11300, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %11306, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %11311 = shape.shape_of %11310 : tensor -> tensor<2xindex> + %11312 = stablehlo.dynamic_broadcast_in_dim %11310, %11311, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %11313 = stablehlo.dynamic_broadcast_in_dim %213, %11311, dims = [] : (tensor, tensor<2xindex>) -> tensor + %11314 = stablehlo.multiply %11312, %11313 : tensor + %dim_3960 = tensor.dim %11274, %c0 : tensor + %11315 = arith.index_cast %dim_3960 : index to i64 + %dim_3961 = tensor.dim %11310, %c0 : tensor + %11316 = arith.index_cast %dim_3961 : index to i64 + %11317 = arith.maxsi %11315, %11316 : i64 + %11318 = arith.index_cast %11317 : i64 to index + %from_elements_3962 = tensor.from_elements %11318, %c4096 : tensor<2xindex> + %11319 = stablehlo.dynamic_broadcast_in_dim %11274, %from_elements_3962, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_3963 = tensor.dim %11319, %c0 : tensor + %11320 = arith.index_cast %dim_3963 : index to i64 + %from_elements_3964 = tensor.from_elements %11320, %c4096_i64 : tensor<2xi64> + %11321 = stablehlo.real_dynamic_slice %11314, %c_22, %from_elements_3964, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_3965 = tensor.from_elements %11320, %c4096_i64, %c1_i64 : tensor<3xi64> + %11322 = stablehlo.dynamic_reshape %11319, %from_elements_3965 : (tensor, tensor<3xi64>) -> tensor + %11323 = stablehlo.dynamic_iota %from_elements_3965, dim = 1 : (tensor<3xi64>) -> tensor + %11324 = stablehlo.concatenate %11322, %11323, dim = 2 : (tensor, tensor) -> tensor + %11325 = "stablehlo.scatter"(%11262, %11324, %11321) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %11326 = stablehlo.slice %11197 [2:3, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %11327 = stablehlo.reshape %11326 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %11328 = stablehlo.custom_call @byteir.non_zero(%11327) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_3966 = 
tensor.dim %11328, %c0 : tensor + %11329 = arith.index_cast %dim_3966 : index to i64 + %from_elements_3967 = tensor.from_elements %11329, %c1_i64 : tensor<2xi64> + %11330 = stablehlo.real_dynamic_slice %11328, %c_22, %from_elements_3967, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3968 = tensor.dim %11330, %c0 : tensor + %11331 = arith.index_cast %dim_3968 : index to i64 + %from_elements_3969 = tensor.from_elements %11331 : tensor<1xi64> + %11332 = stablehlo.dynamic_reshape %11330, %from_elements_3969 : (tensor, tensor<1xi64>) -> tensor + %from_elements_3970 = tensor.from_elements %11329, %c2_i64 : tensor<2xi64> + %11333 = stablehlo.real_dynamic_slice %11328, %c_24, %from_elements_3970, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3971 = tensor.dim %11333, %c0 : tensor + %11334 = arith.index_cast %dim_3971 : index to i64 + %from_elements_3972 = tensor.from_elements %11334 : tensor<1xi64> + %11335 = stablehlo.dynamic_reshape %11333, %from_elements_3972 : (tensor, tensor<1xi64>) -> tensor + %dim_3973 = tensor.dim %11335, %c0 : tensor + %11336 = arith.index_cast %dim_3973 : index to i64 + %from_elements_3974 = tensor.from_elements %11336, %c1_i64 : tensor<2xi64> + %11337 = stablehlo.dynamic_reshape %11335, %from_elements_3974 : (tensor, tensor<2xi64>) -> tensor + %dim_3975 = tensor.dim %11337, %c0 : tensor + %11338 = arith.index_cast %dim_3975 : index to i64 + %from_elements_3976 = tensor.from_elements %c1_i64, %11338, %c4096_i64 : tensor<3xi64> + %11339 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_3976, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3977 = tensor.dim %11339, %c1 : tensor<1x?x4096xi64> + %11340 = arith.index_cast %dim_3977 : index to i64 + %from_elements_3978 = tensor.from_elements %c1_i64, %11340, %c4096_i64, %c1_i64 : tensor<4xi64> + %11341 = stablehlo.dynamic_reshape %11339, %from_elements_3978 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %11342 = stablehlo.dynamic_broadcast_in_dim %11337, %from_elements_3976, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3979 = tensor.dim %11342, %c1 : tensor<1x?x4096xi64> + %11343 = arith.index_cast %dim_3979 : index to i64 + %from_elements_3980 = tensor.from_elements %c1_i64, %11343, %c4096_i64, %c1_i64 : tensor<4xi64> + %11344 = stablehlo.dynamic_reshape %11342, %from_elements_3980 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %11345 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_3976, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_3981 = tensor.dim %11345, %c1 : tensor<1x?x4096xi64> + %11346 = arith.index_cast %dim_3981 : index to i64 + %from_elements_3982 = tensor.from_elements %c1_i64, %11346, %c4096_i64, %c1_i64 : tensor<4xi64> + %11347 = stablehlo.dynamic_reshape %11345, %from_elements_3982 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %11348 = stablehlo.concatenate %11341, %11344, %11347, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %11349 = "stablehlo.gather"(%11208, %11348) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %11350 = shape.shape_of %11349 : tensor<1x?x4096xf32> -> tensor<3xindex> + %11351 = shape.num_elements %11350 : tensor<3xindex> -> index + %11352 = stablehlo.compute_reshape_shape 
%11351, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %11353 = stablehlo.dynamic_reshape %11349, %11352 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %11354 = stablehlo.dot %11353, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %11355 = stablehlo.logistic %11354 : tensor + %11356 = shape.shape_of %11355 : tensor -> tensor<2xindex> + %11357 = shape.shape_of %11354 : tensor -> tensor<2xindex> + %11358 = shape.cstr_broadcastable %11356, %11357 : tensor<2xindex>, tensor<2xindex> + %11359 = shape.assuming %11358 -> (tensor) { + %19688 = shape.broadcast %11356, %11357 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %11355, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %11354, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %11360 = shape.shape_of %11359 : tensor -> tensor<2xindex> + %11361 = shape.cstr_broadcastable %11360, %11357 : tensor<2xindex>, tensor<2xindex> + %11362 = shape.assuming %11361 -> (tensor) { + %19688 = shape.broadcast %11360, %11357 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %11359, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %11354, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %11363 = stablehlo.dot %11362, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_3983 = tensor.dim %11335, %c0 : tensor + %11364 = arith.index_cast %dim_3983 : index to i64 + %from_elements_3984 = tensor.from_elements %11364, %c1_i64 : tensor<2xi64> + %11365 = stablehlo.dynamic_reshape %11335, %from_elements_3984 : (tensor, tensor<2xi64>) -> tensor + %dim_3985 = tensor.dim %11332, %c0 : tensor + %11366 = arith.index_cast %dim_3985 : index to i64 + %from_elements_3986 = tensor.from_elements %11366, %c1_i64 : tensor<2xi64> + %11367 = stablehlo.dynamic_reshape %11332, %from_elements_3986 : (tensor, tensor<2xi64>) -> tensor + %11368 = stablehlo.concatenate %11365, %11367, dim = 1 : (tensor, tensor) -> tensor + %11369 = "stablehlo.gather"(%11237, %11368) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %11370 = shape.shape_of %11363 : tensor -> tensor<2xindex> + %11371 = shape.shape_of %11369 : tensor -> tensor<2xindex> + %11372 = shape.cstr_broadcastable %11370, %11371 : tensor<2xindex>, tensor<2xindex> + %11373 = shape.assuming %11372 -> (tensor) { + %19688 = shape.broadcast %11370, %11371 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %11363, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %11369, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %11374 = shape.shape_of %11373 : tensor -> tensor<2xindex> + %11375 = stablehlo.dynamic_broadcast_in_dim %11373, %11374, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %11376 = stablehlo.dynamic_broadcast_in_dim %213, %11374, dims = [] : (tensor, tensor<2xindex>) -> tensor + %11377 = stablehlo.multiply %11375, %11376 : tensor + %dim_3987 = tensor.dim %11337, %c0 : tensor + %11378 = 
arith.index_cast %dim_3987 : index to i64 + %dim_3988 = tensor.dim %11373, %c0 : tensor + %11379 = arith.index_cast %dim_3988 : index to i64 + %11380 = arith.maxsi %11378, %11379 : i64 + %11381 = arith.index_cast %11380 : i64 to index + %from_elements_3989 = tensor.from_elements %11381, %c4096 : tensor<2xindex> + %11382 = stablehlo.dynamic_broadcast_in_dim %11337, %from_elements_3989, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_3990 = tensor.dim %11382, %c0 : tensor + %11383 = arith.index_cast %dim_3990 : index to i64 + %from_elements_3991 = tensor.from_elements %11383, %c4096_i64 : tensor<2xi64> + %11384 = stablehlo.real_dynamic_slice %11377, %c_22, %from_elements_3991, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_3992 = tensor.from_elements %11383, %c4096_i64, %c1_i64 : tensor<3xi64> + %11385 = stablehlo.dynamic_reshape %11382, %from_elements_3992 : (tensor, tensor<3xi64>) -> tensor + %11386 = stablehlo.dynamic_iota %from_elements_3992, dim = 1 : (tensor<3xi64>) -> tensor + %11387 = stablehlo.concatenate %11385, %11386, dim = 2 : (tensor, tensor) -> tensor + %11388 = "stablehlo.scatter"(%11325, %11387, %11384) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %11389 = stablehlo.slice %11197 [3:4, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %11390 = stablehlo.reshape %11389 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %11391 = stablehlo.custom_call @byteir.non_zero(%11390) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_3993 = tensor.dim %11391, %c0 : tensor + %11392 = arith.index_cast %dim_3993 : index to i64 + %from_elements_3994 = tensor.from_elements %11392, %c1_i64 : tensor<2xi64> + %11393 = stablehlo.real_dynamic_slice %11391, %c_22, %from_elements_3994, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3995 = tensor.dim %11393, %c0 : tensor + %11394 = arith.index_cast %dim_3995 : index to i64 + %from_elements_3996 = tensor.from_elements %11394 : tensor<1xi64> + %11395 = stablehlo.dynamic_reshape %11393, %from_elements_3996 : (tensor, tensor<1xi64>) -> tensor + %from_elements_3997 = tensor.from_elements %11392, %c2_i64 : tensor<2xi64> + %11396 = stablehlo.real_dynamic_slice %11391, %c_24, %from_elements_3997, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_3998 = tensor.dim %11396, %c0 : tensor + %11397 = arith.index_cast %dim_3998 : index to i64 + %from_elements_3999 = tensor.from_elements %11397 : tensor<1xi64> + %11398 = stablehlo.dynamic_reshape %11396, %from_elements_3999 : (tensor, tensor<1xi64>) -> tensor + %dim_4000 = tensor.dim %11398, %c0 : tensor + %11399 = arith.index_cast %dim_4000 : index to i64 + %from_elements_4001 = tensor.from_elements %11399, %c1_i64 : tensor<2xi64> + %11400 = stablehlo.dynamic_reshape %11398, %from_elements_4001 : (tensor, tensor<2xi64>) -> tensor + %dim_4002 = tensor.dim %11400, %c0 : tensor + %11401 = arith.index_cast %dim_4002 : index to i64 + %from_elements_4003 = tensor.from_elements %c1_i64, %11401, %c4096_i64 : tensor<3xi64> + %11402 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_4003, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4004 = tensor.dim %11402, %c1 : tensor<1x?x4096xi64> + %11403 = 
arith.index_cast %dim_4004 : index to i64 + %from_elements_4005 = tensor.from_elements %c1_i64, %11403, %c4096_i64, %c1_i64 : tensor<4xi64> + %11404 = stablehlo.dynamic_reshape %11402, %from_elements_4005 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %11405 = stablehlo.dynamic_broadcast_in_dim %11400, %from_elements_4003, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4006 = tensor.dim %11405, %c1 : tensor<1x?x4096xi64> + %11406 = arith.index_cast %dim_4006 : index to i64 + %from_elements_4007 = tensor.from_elements %c1_i64, %11406, %c4096_i64, %c1_i64 : tensor<4xi64> + %11407 = stablehlo.dynamic_reshape %11405, %from_elements_4007 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %11408 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_4003, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4008 = tensor.dim %11408, %c1 : tensor<1x?x4096xi64> + %11409 = arith.index_cast %dim_4008 : index to i64 + %from_elements_4009 = tensor.from_elements %c1_i64, %11409, %c4096_i64, %c1_i64 : tensor<4xi64> + %11410 = stablehlo.dynamic_reshape %11408, %from_elements_4009 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %11411 = stablehlo.concatenate %11404, %11407, %11410, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %11412 = "stablehlo.gather"(%11208, %11411) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %11413 = shape.shape_of %11412 : tensor<1x?x4096xf32> -> tensor<3xindex> + %11414 = shape.num_elements %11413 : tensor<3xindex> -> index + %11415 = stablehlo.compute_reshape_shape %11414, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %11416 = stablehlo.dynamic_reshape %11412, %11415 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %11417 = stablehlo.dot %11416, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %11418 = stablehlo.logistic %11417 : tensor + %11419 = shape.shape_of %11418 : tensor -> tensor<2xindex> + %11420 = shape.shape_of %11417 : tensor -> tensor<2xindex> + %11421 = shape.cstr_broadcastable %11419, %11420 : tensor<2xindex>, tensor<2xindex> + %11422 = shape.assuming %11421 -> (tensor) { + %19688 = shape.broadcast %11419, %11420 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %11418, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %11417, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %11423 = shape.shape_of %11422 : tensor -> tensor<2xindex> + %11424 = shape.cstr_broadcastable %11423, %11420 : tensor<2xindex>, tensor<2xindex> + %11425 = shape.assuming %11424 -> (tensor) { + %19688 = shape.broadcast %11423, %11420 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %11422, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %11417, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %11426 = stablehlo.dot %11425, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_4010 = tensor.dim %11398, %c0 : tensor + %11427 = arith.index_cast %dim_4010 : index to 
i64 + %from_elements_4011 = tensor.from_elements %11427, %c1_i64 : tensor<2xi64> + %11428 = stablehlo.dynamic_reshape %11398, %from_elements_4011 : (tensor, tensor<2xi64>) -> tensor + %dim_4012 = tensor.dim %11395, %c0 : tensor + %11429 = arith.index_cast %dim_4012 : index to i64 + %from_elements_4013 = tensor.from_elements %11429, %c1_i64 : tensor<2xi64> + %11430 = stablehlo.dynamic_reshape %11395, %from_elements_4013 : (tensor, tensor<2xi64>) -> tensor + %11431 = stablehlo.concatenate %11428, %11430, dim = 1 : (tensor, tensor) -> tensor + %11432 = "stablehlo.gather"(%11237, %11431) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %11433 = shape.shape_of %11426 : tensor -> tensor<2xindex> + %11434 = shape.shape_of %11432 : tensor -> tensor<2xindex> + %11435 = shape.cstr_broadcastable %11433, %11434 : tensor<2xindex>, tensor<2xindex> + %11436 = shape.assuming %11435 -> (tensor) { + %19688 = shape.broadcast %11433, %11434 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %11426, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %11432, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %11437 = shape.shape_of %11436 : tensor -> tensor<2xindex> + %11438 = stablehlo.dynamic_broadcast_in_dim %11436, %11437, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %11439 = stablehlo.dynamic_broadcast_in_dim %213, %11437, dims = [] : (tensor, tensor<2xindex>) -> tensor + %11440 = stablehlo.multiply %11438, %11439 : tensor + %dim_4014 = tensor.dim %11400, %c0 : tensor + %11441 = arith.index_cast %dim_4014 : index to i64 + %dim_4015 = tensor.dim %11436, %c0 : tensor + %11442 = arith.index_cast %dim_4015 : index to i64 + %11443 = arith.maxsi %11441, %11442 : i64 + %11444 = arith.index_cast %11443 : i64 to index + %from_elements_4016 = tensor.from_elements %11444, %c4096 : tensor<2xindex> + %11445 = stablehlo.dynamic_broadcast_in_dim %11400, %from_elements_4016, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_4017 = tensor.dim %11445, %c0 : tensor + %11446 = arith.index_cast %dim_4017 : index to i64 + %from_elements_4018 = tensor.from_elements %11446, %c4096_i64 : tensor<2xi64> + %11447 = stablehlo.real_dynamic_slice %11440, %c_22, %from_elements_4018, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_4019 = tensor.from_elements %11446, %c4096_i64, %c1_i64 : tensor<3xi64> + %11448 = stablehlo.dynamic_reshape %11445, %from_elements_4019 : (tensor, tensor<3xi64>) -> tensor + %11449 = stablehlo.dynamic_iota %from_elements_4019, dim = 1 : (tensor<3xi64>) -> tensor + %11450 = stablehlo.concatenate %11448, %11449, dim = 2 : (tensor, tensor) -> tensor + %11451 = "stablehlo.scatter"(%11388, %11450, %11447) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %11452 = stablehlo.slice %11197 [4:5, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %11453 = stablehlo.reshape %11452 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %11454 = stablehlo.custom_call @byteir.non_zero(%11453) {byteir_attrs = {}} : (tensor<2x3xi64>) -> 
tensor + %dim_4020 = tensor.dim %11454, %c0 : tensor + %11455 = arith.index_cast %dim_4020 : index to i64 + %from_elements_4021 = tensor.from_elements %11455, %c1_i64 : tensor<2xi64> + %11456 = stablehlo.real_dynamic_slice %11454, %c_22, %from_elements_4021, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4022 = tensor.dim %11456, %c0 : tensor + %11457 = arith.index_cast %dim_4022 : index to i64 + %from_elements_4023 = tensor.from_elements %11457 : tensor<1xi64> + %11458 = stablehlo.dynamic_reshape %11456, %from_elements_4023 : (tensor, tensor<1xi64>) -> tensor + %from_elements_4024 = tensor.from_elements %11455, %c2_i64 : tensor<2xi64> + %11459 = stablehlo.real_dynamic_slice %11454, %c_24, %from_elements_4024, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4025 = tensor.dim %11459, %c0 : tensor + %11460 = arith.index_cast %dim_4025 : index to i64 + %from_elements_4026 = tensor.from_elements %11460 : tensor<1xi64> + %11461 = stablehlo.dynamic_reshape %11459, %from_elements_4026 : (tensor, tensor<1xi64>) -> tensor + %dim_4027 = tensor.dim %11461, %c0 : tensor + %11462 = arith.index_cast %dim_4027 : index to i64 + %from_elements_4028 = tensor.from_elements %11462, %c1_i64 : tensor<2xi64> + %11463 = stablehlo.dynamic_reshape %11461, %from_elements_4028 : (tensor, tensor<2xi64>) -> tensor + %dim_4029 = tensor.dim %11463, %c0 : tensor + %11464 = arith.index_cast %dim_4029 : index to i64 + %from_elements_4030 = tensor.from_elements %c1_i64, %11464, %c4096_i64 : tensor<3xi64> + %11465 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_4030, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4031 = tensor.dim %11465, %c1 : tensor<1x?x4096xi64> + %11466 = arith.index_cast %dim_4031 : index to i64 + %from_elements_4032 = tensor.from_elements %c1_i64, %11466, %c4096_i64, %c1_i64 : tensor<4xi64> + %11467 = stablehlo.dynamic_reshape %11465, %from_elements_4032 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %11468 = stablehlo.dynamic_broadcast_in_dim %11463, %from_elements_4030, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4033 = tensor.dim %11468, %c1 : tensor<1x?x4096xi64> + %11469 = arith.index_cast %dim_4033 : index to i64 + %from_elements_4034 = tensor.from_elements %c1_i64, %11469, %c4096_i64, %c1_i64 : tensor<4xi64> + %11470 = stablehlo.dynamic_reshape %11468, %from_elements_4034 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %11471 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_4030, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4035 = tensor.dim %11471, %c1 : tensor<1x?x4096xi64> + %11472 = arith.index_cast %dim_4035 : index to i64 + %from_elements_4036 = tensor.from_elements %c1_i64, %11472, %c4096_i64, %c1_i64 : tensor<4xi64> + %11473 = stablehlo.dynamic_reshape %11471, %from_elements_4036 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %11474 = stablehlo.concatenate %11467, %11470, %11473, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %11475 = "stablehlo.gather"(%11208, %11474) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %11476 = shape.shape_of %11475 : tensor<1x?x4096xf32> -> tensor<3xindex> + %11477 = shape.num_elements %11476 : tensor<3xindex> -> index + %11478 = 
stablehlo.compute_reshape_shape %11477, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %11479 = stablehlo.dynamic_reshape %11475, %11478 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %11480 = stablehlo.dot %11479, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %11481 = stablehlo.logistic %11480 : tensor + %11482 = shape.shape_of %11481 : tensor -> tensor<2xindex> + %11483 = shape.shape_of %11480 : tensor -> tensor<2xindex> + %11484 = shape.cstr_broadcastable %11482, %11483 : tensor<2xindex>, tensor<2xindex> + %11485 = shape.assuming %11484 -> (tensor) { + %19688 = shape.broadcast %11482, %11483 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %11481, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %11480, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %11486 = shape.shape_of %11485 : tensor -> tensor<2xindex> + %11487 = shape.cstr_broadcastable %11486, %11483 : tensor<2xindex>, tensor<2xindex> + %11488 = shape.assuming %11487 -> (tensor) { + %19688 = shape.broadcast %11486, %11483 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %11485, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %11480, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %11489 = stablehlo.dot %11488, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_4037 = tensor.dim %11461, %c0 : tensor + %11490 = arith.index_cast %dim_4037 : index to i64 + %from_elements_4038 = tensor.from_elements %11490, %c1_i64 : tensor<2xi64> + %11491 = stablehlo.dynamic_reshape %11461, %from_elements_4038 : (tensor, tensor<2xi64>) -> tensor + %dim_4039 = tensor.dim %11458, %c0 : tensor + %11492 = arith.index_cast %dim_4039 : index to i64 + %from_elements_4040 = tensor.from_elements %11492, %c1_i64 : tensor<2xi64> + %11493 = stablehlo.dynamic_reshape %11458, %from_elements_4040 : (tensor, tensor<2xi64>) -> tensor + %11494 = stablehlo.concatenate %11491, %11493, dim = 1 : (tensor, tensor) -> tensor + %11495 = "stablehlo.gather"(%11237, %11494) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %11496 = shape.shape_of %11489 : tensor -> tensor<2xindex> + %11497 = shape.shape_of %11495 : tensor -> tensor<2xindex> + %11498 = shape.cstr_broadcastable %11496, %11497 : tensor<2xindex>, tensor<2xindex> + %11499 = shape.assuming %11498 -> (tensor) { + %19688 = shape.broadcast %11496, %11497 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %11489, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %11495, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %11500 = shape.shape_of %11499 : tensor -> tensor<2xindex> + %11501 = stablehlo.dynamic_broadcast_in_dim %11499, %11500, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %11502 = stablehlo.dynamic_broadcast_in_dim %213, %11500, dims = [] : (tensor, tensor<2xindex>) -> tensor + %11503 = stablehlo.multiply %11501, %11502 : tensor + %dim_4041 = tensor.dim 
%11463, %c0 : tensor + %11504 = arith.index_cast %dim_4041 : index to i64 + %dim_4042 = tensor.dim %11499, %c0 : tensor + %11505 = arith.index_cast %dim_4042 : index to i64 + %11506 = arith.maxsi %11504, %11505 : i64 + %11507 = arith.index_cast %11506 : i64 to index + %from_elements_4043 = tensor.from_elements %11507, %c4096 : tensor<2xindex> + %11508 = stablehlo.dynamic_broadcast_in_dim %11463, %from_elements_4043, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_4044 = tensor.dim %11508, %c0 : tensor + %11509 = arith.index_cast %dim_4044 : index to i64 + %from_elements_4045 = tensor.from_elements %11509, %c4096_i64 : tensor<2xi64> + %11510 = stablehlo.real_dynamic_slice %11503, %c_22, %from_elements_4045, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_4046 = tensor.from_elements %11509, %c4096_i64, %c1_i64 : tensor<3xi64> + %11511 = stablehlo.dynamic_reshape %11508, %from_elements_4046 : (tensor, tensor<3xi64>) -> tensor + %11512 = stablehlo.dynamic_iota %from_elements_4046, dim = 1 : (tensor<3xi64>) -> tensor + %11513 = stablehlo.concatenate %11511, %11512, dim = 2 : (tensor, tensor) -> tensor + %11514 = "stablehlo.scatter"(%11451, %11513, %11510) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %11515 = stablehlo.slice %11197 [5:6, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %11516 = stablehlo.reshape %11515 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %11517 = stablehlo.custom_call @byteir.non_zero(%11516) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_4047 = tensor.dim %11517, %c0 : tensor + %11518 = arith.index_cast %dim_4047 : index to i64 + %from_elements_4048 = tensor.from_elements %11518, %c1_i64 : tensor<2xi64> + %11519 = stablehlo.real_dynamic_slice %11517, %c_22, %from_elements_4048, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4049 = tensor.dim %11519, %c0 : tensor + %11520 = arith.index_cast %dim_4049 : index to i64 + %from_elements_4050 = tensor.from_elements %11520 : tensor<1xi64> + %11521 = stablehlo.dynamic_reshape %11519, %from_elements_4050 : (tensor, tensor<1xi64>) -> tensor + %from_elements_4051 = tensor.from_elements %11518, %c2_i64 : tensor<2xi64> + %11522 = stablehlo.real_dynamic_slice %11517, %c_24, %from_elements_4051, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4052 = tensor.dim %11522, %c0 : tensor + %11523 = arith.index_cast %dim_4052 : index to i64 + %from_elements_4053 = tensor.from_elements %11523 : tensor<1xi64> + %11524 = stablehlo.dynamic_reshape %11522, %from_elements_4053 : (tensor, tensor<1xi64>) -> tensor + %dim_4054 = tensor.dim %11524, %c0 : tensor + %11525 = arith.index_cast %dim_4054 : index to i64 + %from_elements_4055 = tensor.from_elements %11525, %c1_i64 : tensor<2xi64> + %11526 = stablehlo.dynamic_reshape %11524, %from_elements_4055 : (tensor, tensor<2xi64>) -> tensor + %dim_4056 = tensor.dim %11526, %c0 : tensor + %11527 = arith.index_cast %dim_4056 : index to i64 + %from_elements_4057 = tensor.from_elements %c1_i64, %11527, %c4096_i64 : tensor<3xi64> + %11528 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_4057, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4058 = tensor.dim %11528, %c1 : 
tensor<1x?x4096xi64> + %11529 = arith.index_cast %dim_4058 : index to i64 + %from_elements_4059 = tensor.from_elements %c1_i64, %11529, %c4096_i64, %c1_i64 : tensor<4xi64> + %11530 = stablehlo.dynamic_reshape %11528, %from_elements_4059 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %11531 = stablehlo.dynamic_broadcast_in_dim %11526, %from_elements_4057, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4060 = tensor.dim %11531, %c1 : tensor<1x?x4096xi64> + %11532 = arith.index_cast %dim_4060 : index to i64 + %from_elements_4061 = tensor.from_elements %c1_i64, %11532, %c4096_i64, %c1_i64 : tensor<4xi64> + %11533 = stablehlo.dynamic_reshape %11531, %from_elements_4061 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %11534 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_4057, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4062 = tensor.dim %11534, %c1 : tensor<1x?x4096xi64> + %11535 = arith.index_cast %dim_4062 : index to i64 + %from_elements_4063 = tensor.from_elements %c1_i64, %11535, %c4096_i64, %c1_i64 : tensor<4xi64> + %11536 = stablehlo.dynamic_reshape %11534, %from_elements_4063 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %11537 = stablehlo.concatenate %11530, %11533, %11536, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %11538 = "stablehlo.gather"(%11208, %11537) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %11539 = shape.shape_of %11538 : tensor<1x?x4096xf32> -> tensor<3xindex> + %11540 = shape.num_elements %11539 : tensor<3xindex> -> index + %11541 = stablehlo.compute_reshape_shape %11540, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %11542 = stablehlo.dynamic_reshape %11538, %11541 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %11543 = stablehlo.dot %11542, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %11544 = stablehlo.logistic %11543 : tensor + %11545 = shape.shape_of %11544 : tensor -> tensor<2xindex> + %11546 = shape.shape_of %11543 : tensor -> tensor<2xindex> + %11547 = shape.cstr_broadcastable %11545, %11546 : tensor<2xindex>, tensor<2xindex> + %11548 = shape.assuming %11547 -> (tensor) { + %19688 = shape.broadcast %11545, %11546 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %11544, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %11543, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %11549 = shape.shape_of %11548 : tensor -> tensor<2xindex> + %11550 = shape.cstr_broadcastable %11549, %11546 : tensor<2xindex>, tensor<2xindex> + %11551 = shape.assuming %11550 -> (tensor) { + %19688 = shape.broadcast %11549, %11546 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %11548, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %11543, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %11552 = stablehlo.dot %11551, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_4064 = tensor.dim %11524, %c0 : tensor + %11553 = 
arith.index_cast %dim_4064 : index to i64 + %from_elements_4065 = tensor.from_elements %11553, %c1_i64 : tensor<2xi64> + %11554 = stablehlo.dynamic_reshape %11524, %from_elements_4065 : (tensor, tensor<2xi64>) -> tensor + %dim_4066 = tensor.dim %11521, %c0 : tensor + %11555 = arith.index_cast %dim_4066 : index to i64 + %from_elements_4067 = tensor.from_elements %11555, %c1_i64 : tensor<2xi64> + %11556 = stablehlo.dynamic_reshape %11521, %from_elements_4067 : (tensor, tensor<2xi64>) -> tensor + %11557 = stablehlo.concatenate %11554, %11556, dim = 1 : (tensor, tensor) -> tensor + %11558 = "stablehlo.gather"(%11237, %11557) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %11559 = shape.shape_of %11552 : tensor -> tensor<2xindex> + %11560 = shape.shape_of %11558 : tensor -> tensor<2xindex> + %11561 = shape.cstr_broadcastable %11559, %11560 : tensor<2xindex>, tensor<2xindex> + %11562 = shape.assuming %11561 -> (tensor) { + %19688 = shape.broadcast %11559, %11560 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %11552, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %11558, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %11563 = shape.shape_of %11562 : tensor -> tensor<2xindex> + %11564 = stablehlo.dynamic_broadcast_in_dim %11562, %11563, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %11565 = stablehlo.dynamic_broadcast_in_dim %213, %11563, dims = [] : (tensor, tensor<2xindex>) -> tensor + %11566 = stablehlo.multiply %11564, %11565 : tensor + %dim_4068 = tensor.dim %11526, %c0 : tensor + %11567 = arith.index_cast %dim_4068 : index to i64 + %dim_4069 = tensor.dim %11562, %c0 : tensor + %11568 = arith.index_cast %dim_4069 : index to i64 + %11569 = arith.maxsi %11567, %11568 : i64 + %11570 = arith.index_cast %11569 : i64 to index + %from_elements_4070 = tensor.from_elements %11570, %c4096 : tensor<2xindex> + %11571 = stablehlo.dynamic_broadcast_in_dim %11526, %from_elements_4070, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_4071 = tensor.dim %11571, %c0 : tensor + %11572 = arith.index_cast %dim_4071 : index to i64 + %from_elements_4072 = tensor.from_elements %11572, %c4096_i64 : tensor<2xi64> + %11573 = stablehlo.real_dynamic_slice %11566, %c_22, %from_elements_4072, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_4073 = tensor.from_elements %11572, %c4096_i64, %c1_i64 : tensor<3xi64> + %11574 = stablehlo.dynamic_reshape %11571, %from_elements_4073 : (tensor, tensor<3xi64>) -> tensor + %11575 = stablehlo.dynamic_iota %from_elements_4073, dim = 1 : (tensor<3xi64>) -> tensor + %11576 = stablehlo.concatenate %11574, %11575, dim = 2 : (tensor, tensor) -> tensor + %11577 = "stablehlo.scatter"(%11514, %11576, %11573) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %11578 = stablehlo.slice %11197 [6:7, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %11579 = stablehlo.reshape %11578 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %11580 = stablehlo.custom_call @byteir.non_zero(%11579) 
{byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_4074 = tensor.dim %11580, %c0 : tensor + %11581 = arith.index_cast %dim_4074 : index to i64 + %from_elements_4075 = tensor.from_elements %11581, %c1_i64 : tensor<2xi64> + %11582 = stablehlo.real_dynamic_slice %11580, %c_22, %from_elements_4075, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4076 = tensor.dim %11582, %c0 : tensor + %11583 = arith.index_cast %dim_4076 : index to i64 + %from_elements_4077 = tensor.from_elements %11583 : tensor<1xi64> + %11584 = stablehlo.dynamic_reshape %11582, %from_elements_4077 : (tensor, tensor<1xi64>) -> tensor + %from_elements_4078 = tensor.from_elements %11581, %c2_i64 : tensor<2xi64> + %11585 = stablehlo.real_dynamic_slice %11580, %c_24, %from_elements_4078, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4079 = tensor.dim %11585, %c0 : tensor + %11586 = arith.index_cast %dim_4079 : index to i64 + %from_elements_4080 = tensor.from_elements %11586 : tensor<1xi64> + %11587 = stablehlo.dynamic_reshape %11585, %from_elements_4080 : (tensor, tensor<1xi64>) -> tensor + %dim_4081 = tensor.dim %11587, %c0 : tensor + %11588 = arith.index_cast %dim_4081 : index to i64 + %from_elements_4082 = tensor.from_elements %11588, %c1_i64 : tensor<2xi64> + %11589 = stablehlo.dynamic_reshape %11587, %from_elements_4082 : (tensor, tensor<2xi64>) -> tensor + %dim_4083 = tensor.dim %11589, %c0 : tensor + %11590 = arith.index_cast %dim_4083 : index to i64 + %from_elements_4084 = tensor.from_elements %c1_i64, %11590, %c4096_i64 : tensor<3xi64> + %11591 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_4084, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4085 = tensor.dim %11591, %c1 : tensor<1x?x4096xi64> + %11592 = arith.index_cast %dim_4085 : index to i64 + %from_elements_4086 = tensor.from_elements %c1_i64, %11592, %c4096_i64, %c1_i64 : tensor<4xi64> + %11593 = stablehlo.dynamic_reshape %11591, %from_elements_4086 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %11594 = stablehlo.dynamic_broadcast_in_dim %11589, %from_elements_4084, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4087 = tensor.dim %11594, %c1 : tensor<1x?x4096xi64> + %11595 = arith.index_cast %dim_4087 : index to i64 + %from_elements_4088 = tensor.from_elements %c1_i64, %11595, %c4096_i64, %c1_i64 : tensor<4xi64> + %11596 = stablehlo.dynamic_reshape %11594, %from_elements_4088 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %11597 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_4084, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4089 = tensor.dim %11597, %c1 : tensor<1x?x4096xi64> + %11598 = arith.index_cast %dim_4089 : index to i64 + %from_elements_4090 = tensor.from_elements %c1_i64, %11598, %c4096_i64, %c1_i64 : tensor<4xi64> + %11599 = stablehlo.dynamic_reshape %11597, %from_elements_4090 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %11600 = stablehlo.concatenate %11593, %11596, %11599, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %11601 = "stablehlo.gather"(%11208, %11600) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %11602 = shape.shape_of %11601 : tensor<1x?x4096xf32> -> tensor<3xindex> + %11603 = shape.num_elements %11602 : 
tensor<3xindex> -> index + %11604 = stablehlo.compute_reshape_shape %11603, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %11605 = stablehlo.dynamic_reshape %11601, %11604 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %11606 = stablehlo.dot %11605, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %11607 = stablehlo.logistic %11606 : tensor + %11608 = shape.shape_of %11607 : tensor -> tensor<2xindex> + %11609 = shape.shape_of %11606 : tensor -> tensor<2xindex> + %11610 = shape.cstr_broadcastable %11608, %11609 : tensor<2xindex>, tensor<2xindex> + %11611 = shape.assuming %11610 -> (tensor) { + %19688 = shape.broadcast %11608, %11609 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %11607, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %11606, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %11612 = shape.shape_of %11611 : tensor -> tensor<2xindex> + %11613 = shape.cstr_broadcastable %11612, %11609 : tensor<2xindex>, tensor<2xindex> + %11614 = shape.assuming %11613 -> (tensor) { + %19688 = shape.broadcast %11612, %11609 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %11611, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %11606, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %11615 = stablehlo.dot %11614, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_4091 = tensor.dim %11587, %c0 : tensor + %11616 = arith.index_cast %dim_4091 : index to i64 + %from_elements_4092 = tensor.from_elements %11616, %c1_i64 : tensor<2xi64> + %11617 = stablehlo.dynamic_reshape %11587, %from_elements_4092 : (tensor, tensor<2xi64>) -> tensor + %dim_4093 = tensor.dim %11584, %c0 : tensor + %11618 = arith.index_cast %dim_4093 : index to i64 + %from_elements_4094 = tensor.from_elements %11618, %c1_i64 : tensor<2xi64> + %11619 = stablehlo.dynamic_reshape %11584, %from_elements_4094 : (tensor, tensor<2xi64>) -> tensor + %11620 = stablehlo.concatenate %11617, %11619, dim = 1 : (tensor, tensor) -> tensor + %11621 = "stablehlo.gather"(%11237, %11620) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %11622 = shape.shape_of %11615 : tensor -> tensor<2xindex> + %11623 = shape.shape_of %11621 : tensor -> tensor<2xindex> + %11624 = shape.cstr_broadcastable %11622, %11623 : tensor<2xindex>, tensor<2xindex> + %11625 = shape.assuming %11624 -> (tensor) { + %19688 = shape.broadcast %11622, %11623 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %11615, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %11621, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %11626 = shape.shape_of %11625 : tensor -> tensor<2xindex> + %11627 = stablehlo.dynamic_broadcast_in_dim %11625, %11626, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %11628 = stablehlo.dynamic_broadcast_in_dim %213, %11626, dims = [] : (tensor, tensor<2xindex>) -> tensor + %11629 = stablehlo.multiply %11627, %11628 : 
tensor + %dim_4095 = tensor.dim %11589, %c0 : tensor + %11630 = arith.index_cast %dim_4095 : index to i64 + %dim_4096 = tensor.dim %11625, %c0 : tensor + %11631 = arith.index_cast %dim_4096 : index to i64 + %11632 = arith.maxsi %11630, %11631 : i64 + %11633 = arith.index_cast %11632 : i64 to index + %from_elements_4097 = tensor.from_elements %11633, %c4096 : tensor<2xindex> + %11634 = stablehlo.dynamic_broadcast_in_dim %11589, %from_elements_4097, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_4098 = tensor.dim %11634, %c0 : tensor + %11635 = arith.index_cast %dim_4098 : index to i64 + %from_elements_4099 = tensor.from_elements %11635, %c4096_i64 : tensor<2xi64> + %11636 = stablehlo.real_dynamic_slice %11629, %c_22, %from_elements_4099, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_4100 = tensor.from_elements %11635, %c4096_i64, %c1_i64 : tensor<3xi64> + %11637 = stablehlo.dynamic_reshape %11634, %from_elements_4100 : (tensor, tensor<3xi64>) -> tensor + %11638 = stablehlo.dynamic_iota %from_elements_4100, dim = 1 : (tensor<3xi64>) -> tensor + %11639 = stablehlo.concatenate %11637, %11638, dim = 2 : (tensor, tensor) -> tensor + %11640 = "stablehlo.scatter"(%11577, %11639, %11636) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %11641 = stablehlo.slice %11197 [7:8, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %11642 = stablehlo.reshape %11641 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %11643 = stablehlo.custom_call @byteir.non_zero(%11642) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_4101 = tensor.dim %11643, %c0 : tensor + %11644 = arith.index_cast %dim_4101 : index to i64 + %from_elements_4102 = tensor.from_elements %11644, %c1_i64 : tensor<2xi64> + %11645 = stablehlo.real_dynamic_slice %11643, %c_22, %from_elements_4102, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4103 = tensor.dim %11645, %c0 : tensor + %11646 = arith.index_cast %dim_4103 : index to i64 + %from_elements_4104 = tensor.from_elements %11646 : tensor<1xi64> + %11647 = stablehlo.dynamic_reshape %11645, %from_elements_4104 : (tensor, tensor<1xi64>) -> tensor + %from_elements_4105 = tensor.from_elements %11644, %c2_i64 : tensor<2xi64> + %11648 = stablehlo.real_dynamic_slice %11643, %c_24, %from_elements_4105, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4106 = tensor.dim %11648, %c0 : tensor + %11649 = arith.index_cast %dim_4106 : index to i64 + %from_elements_4107 = tensor.from_elements %11649 : tensor<1xi64> + %11650 = stablehlo.dynamic_reshape %11648, %from_elements_4107 : (tensor, tensor<1xi64>) -> tensor + %dim_4108 = tensor.dim %11650, %c0 : tensor + %11651 = arith.index_cast %dim_4108 : index to i64 + %from_elements_4109 = tensor.from_elements %11651, %c1_i64 : tensor<2xi64> + %11652 = stablehlo.dynamic_reshape %11650, %from_elements_4109 : (tensor, tensor<2xi64>) -> tensor + %dim_4110 = tensor.dim %11652, %c0 : tensor + %11653 = arith.index_cast %dim_4110 : index to i64 + %from_elements_4111 = tensor.from_elements %c1_i64, %11653, %c4096_i64 : tensor<3xi64> + %11654 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_4111, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4112 = 
tensor.dim %11654, %c1 : tensor<1x?x4096xi64> + %11655 = arith.index_cast %dim_4112 : index to i64 + %from_elements_4113 = tensor.from_elements %c1_i64, %11655, %c4096_i64, %c1_i64 : tensor<4xi64> + %11656 = stablehlo.dynamic_reshape %11654, %from_elements_4113 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %11657 = stablehlo.dynamic_broadcast_in_dim %11652, %from_elements_4111, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4114 = tensor.dim %11657, %c1 : tensor<1x?x4096xi64> + %11658 = arith.index_cast %dim_4114 : index to i64 + %from_elements_4115 = tensor.from_elements %c1_i64, %11658, %c4096_i64, %c1_i64 : tensor<4xi64> + %11659 = stablehlo.dynamic_reshape %11657, %from_elements_4115 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %11660 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_4111, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4116 = tensor.dim %11660, %c1 : tensor<1x?x4096xi64> + %11661 = arith.index_cast %dim_4116 : index to i64 + %from_elements_4117 = tensor.from_elements %c1_i64, %11661, %c4096_i64, %c1_i64 : tensor<4xi64> + %11662 = stablehlo.dynamic_reshape %11660, %from_elements_4117 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %11663 = stablehlo.concatenate %11656, %11659, %11662, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %11664 = "stablehlo.gather"(%11208, %11663) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %11665 = shape.shape_of %11664 : tensor<1x?x4096xf32> -> tensor<3xindex> + %11666 = shape.num_elements %11665 : tensor<3xindex> -> index + %11667 = stablehlo.compute_reshape_shape %11666, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %11668 = stablehlo.dynamic_reshape %11664, %11667 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %11669 = stablehlo.dot %11668, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %11670 = stablehlo.logistic %11669 : tensor + %11671 = shape.shape_of %11670 : tensor -> tensor<2xindex> + %11672 = shape.shape_of %11669 : tensor -> tensor<2xindex> + %11673 = shape.cstr_broadcastable %11671, %11672 : tensor<2xindex>, tensor<2xindex> + %11674 = shape.assuming %11673 -> (tensor) { + %19688 = shape.broadcast %11671, %11672 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %11670, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %11669, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %11675 = shape.shape_of %11674 : tensor -> tensor<2xindex> + %11676 = shape.cstr_broadcastable %11675, %11672 : tensor<2xindex>, tensor<2xindex> + %11677 = shape.assuming %11676 -> (tensor) { + %19688 = shape.broadcast %11675, %11672 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %11674, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %11669, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %11678 = stablehlo.dot %11677, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_4118 = tensor.dim %11650, %c0 : 
tensor + %11679 = arith.index_cast %dim_4118 : index to i64 + %from_elements_4119 = tensor.from_elements %11679, %c1_i64 : tensor<2xi64> + %11680 = stablehlo.dynamic_reshape %11650, %from_elements_4119 : (tensor, tensor<2xi64>) -> tensor + %dim_4120 = tensor.dim %11647, %c0 : tensor + %11681 = arith.index_cast %dim_4120 : index to i64 + %from_elements_4121 = tensor.from_elements %11681, %c1_i64 : tensor<2xi64> + %11682 = stablehlo.dynamic_reshape %11647, %from_elements_4121 : (tensor, tensor<2xi64>) -> tensor + %11683 = stablehlo.concatenate %11680, %11682, dim = 1 : (tensor, tensor) -> tensor + %11684 = "stablehlo.gather"(%11237, %11683) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %11685 = shape.shape_of %11678 : tensor -> tensor<2xindex> + %11686 = shape.shape_of %11684 : tensor -> tensor<2xindex> + %11687 = shape.cstr_broadcastable %11685, %11686 : tensor<2xindex>, tensor<2xindex> + %11688 = shape.assuming %11687 -> (tensor) { + %19688 = shape.broadcast %11685, %11686 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %11678, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %11684, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %11689 = shape.shape_of %11688 : tensor -> tensor<2xindex> + %11690 = stablehlo.dynamic_broadcast_in_dim %11688, %11689, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %11691 = stablehlo.dynamic_broadcast_in_dim %213, %11689, dims = [] : (tensor, tensor<2xindex>) -> tensor + %11692 = stablehlo.multiply %11690, %11691 : tensor + %dim_4122 = tensor.dim %11652, %c0 : tensor + %11693 = arith.index_cast %dim_4122 : index to i64 + %dim_4123 = tensor.dim %11688, %c0 : tensor + %11694 = arith.index_cast %dim_4123 : index to i64 + %11695 = arith.maxsi %11693, %11694 : i64 + %11696 = arith.index_cast %11695 : i64 to index + %from_elements_4124 = tensor.from_elements %11696, %c4096 : tensor<2xindex> + %11697 = stablehlo.dynamic_broadcast_in_dim %11652, %from_elements_4124, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_4125 = tensor.dim %11697, %c0 : tensor + %11698 = arith.index_cast %dim_4125 : index to i64 + %from_elements_4126 = tensor.from_elements %11698, %c4096_i64 : tensor<2xi64> + %11699 = stablehlo.real_dynamic_slice %11692, %c_22, %from_elements_4126, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_4127 = tensor.from_elements %11698, %c4096_i64, %c1_i64 : tensor<3xi64> + %11700 = stablehlo.dynamic_reshape %11697, %from_elements_4127 : (tensor, tensor<3xi64>) -> tensor + %11701 = stablehlo.dynamic_iota %from_elements_4127, dim = 1 : (tensor<3xi64>) -> tensor + %11702 = stablehlo.concatenate %11700, %11701, dim = 2 : (tensor, tensor) -> tensor + %11703 = "stablehlo.scatter"(%11640, %11702, %11699) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %11704 = stablehlo.reshape %11703 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %11705 = stablehlo.add %11170, %11704 : tensor<3x1x4096xf32> + %11706 = stablehlo.broadcast_in_dim %11705, dims = [0, 1, 2] : 
(tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %11707 = stablehlo.power %11706, %15 : tensor<3x1x4096xf32> + %11708 = stablehlo.reduce(%11707 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %11709 = stablehlo.reshape %11708 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %11710 = stablehlo.broadcast_in_dim %11709, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %11711 = stablehlo.divide %11710, %21 : tensor<3x1x1xf32> + %11712 = stablehlo.broadcast_in_dim %11711, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %11713 = stablehlo.add %11712, %25 : tensor<3x1x1xf32> + %11714 = stablehlo.rsqrt %11713 : tensor<3x1x1xf32> + %11715 = stablehlo.broadcast_in_dim %11714, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %11716 = stablehlo.multiply %11706, %11715 : tensor<3x1x4096xf32> + %11717 = stablehlo.broadcast_in_dim %11716, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %11718 = stablehlo.multiply %11717, %31 : tensor<3x1x4096xf32> + %11719 = stablehlo.reshape %11718 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %11720 = stablehlo.dot %11719, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %11721 = stablehlo.reshape %11720 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %11722 = stablehlo.dot %11719, %37 : (tensor<3x4096xf32>, tensor<4096x1024xf32>) -> tensor<3x1024xf32> + %11723 = stablehlo.reshape %11722 : (tensor<3x1024xf32>) -> tensor<3x1x1024xf32> + %11724 = stablehlo.reshape %11721 : (tensor<3x1x4096xf32>) -> tensor<3x1x32x128xf32> + %11725 = stablehlo.transpose %11724, dims = [0, 2, 1, 3] : (tensor<3x1x32x128xf32>) -> tensor<3x32x1x128xf32> + %11726 = stablehlo.reshape %11723 : (tensor<3x1x1024xf32>) -> tensor<3x1x8x128xf32> + %11727 = stablehlo.transpose %11726, dims = [0, 2, 1, 3] : (tensor<3x1x8x128xf32>) -> tensor<3x8x1x128xf32> + %11728 = stablehlo.slice %arg38 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %11729 = stablehlo.slice %arg39 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %11730 = "stablehlo.gather"(%11728, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %11731 = stablehlo.reshape %11730 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %11732 = "stablehlo.gather"(%11729, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %11733 = stablehlo.reshape %11732 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %11734 = stablehlo.broadcast_in_dim %11725, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %11735 = stablehlo.broadcast_in_dim %11731, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %11736 = stablehlo.multiply %11734, %11735 : tensor<3x32x1x128xf32> + %11737 = stablehlo.slice %11725 [0:3, 0:32, 0:1, 0:64] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %11738 = stablehlo.slice %11725 [0:3, 0:32, 0:1, 64:128] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %11739 = stablehlo.negate %11738 : tensor<3x32x1x64xf32> + %11740 = stablehlo.concatenate %11739, %11737, dim = 3 : (tensor<3x32x1x64xf32>, tensor<3x32x1x64xf32>) -> tensor<3x32x1x128xf32> + %11741 = stablehlo.broadcast_in_dim %11740, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %11742 = 
stablehlo.broadcast_in_dim %11733, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %11743 = stablehlo.multiply %11741, %11742 : tensor<3x32x1x128xf32> + %11744 = stablehlo.add %11736, %11743 : tensor<3x32x1x128xf32> + %11745 = stablehlo.broadcast_in_dim %11727, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %11746 = stablehlo.broadcast_in_dim %11731, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %11747 = stablehlo.multiply %11745, %11746 : tensor<3x8x1x128xf32> + %11748 = stablehlo.slice %11727 [0:3, 0:8, 0:1, 0:64] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %11749 = stablehlo.slice %11727 [0:3, 0:8, 0:1, 64:128] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %11750 = stablehlo.negate %11749 : tensor<3x8x1x64xf32> + %11751 = stablehlo.concatenate %11750, %11748, dim = 3 : (tensor<3x8x1x64xf32>, tensor<3x8x1x64xf32>) -> tensor<3x8x1x128xf32> + %11752 = stablehlo.broadcast_in_dim %11751, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %11753 = stablehlo.broadcast_in_dim %11733, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %11754 = stablehlo.multiply %11752, %11753 : tensor<3x8x1x128xf32> + %11755 = stablehlo.add %11747, %11754 : tensor<3x8x1x128xf32> + %11756 = stablehlo.concatenate %arg103, %11755, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %11757 = stablehlo.concatenate %arg104, %11727, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %11758 = stablehlo.reshape %11756 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %11759 = stablehlo.broadcast_in_dim %11758, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %11760 = stablehlo.reshape %11759 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %11761 = stablehlo.reshape %11757 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %11762 = stablehlo.broadcast_in_dim %11761, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %11763 = stablehlo.reshape %11762 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %11764 = stablehlo.transpose %11760, dims = [0, 1, 3, 2] : (tensor<3x32x8x128xf32>) -> tensor<3x32x128x8xf32> + %11765 = stablehlo.reshape %11744 : (tensor<3x32x1x128xf32>) -> tensor<96x1x128xf32> + %11766 = stablehlo.reshape %11764 : (tensor<3x32x128x8xf32>) -> tensor<96x128x8xf32> + %11767 = stablehlo.broadcast_in_dim %11766, dims = [0, 1, 2] : (tensor<96x128x8xf32>) -> tensor<96x128x8xf32> + %11768 = stablehlo.dot_general %11765, %11767, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x128xf32>, tensor<96x128x8xf32>) -> tensor<96x1x8xf32> + %11769 = stablehlo.reshape %11768 : (tensor<96x1x8xf32>) -> tensor<3x32x1x8xf32> + %11770 = stablehlo.broadcast_in_dim %11769, dims = [0, 1, 2, 3] : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %11771 = stablehlo.divide %11770, %89 : tensor<3x32x1x8xf32> + %11772 = stablehlo.custom_call @byteir.softmax(%11771) {byteir_attrs = {axis = 3 : i64}} : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %11773 = stablehlo.reshape %11772 : (tensor<3x32x1x8xf32>) -> tensor<96x1x8xf32> + %11774 = stablehlo.reshape %11763 : (tensor<3x32x8x128xf32>) -> tensor<96x8x128xf32> + %11775 = stablehlo.broadcast_in_dim %11774, dims = [0, 1, 2] : (tensor<96x8x128xf32>) -> tensor<96x8x128xf32> + %11776 = stablehlo.dot_general %11773, %11775, batching_dims = [0] x [0], contracting_dims = [2] x [1] : 
(tensor<96x1x8xf32>, tensor<96x8x128xf32>) -> tensor<96x1x128xf32> + %11777 = stablehlo.reshape %11776 : (tensor<96x1x128xf32>) -> tensor<3x32x1x128xf32> + %11778 = stablehlo.transpose %11777, dims = [0, 2, 1, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x1x32x128xf32> + %11779 = stablehlo.reshape %11778 : (tensor<3x1x32x128xf32>) -> tensor<3x1x4096xf32> + %11780 = stablehlo.reshape %11779 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %11781 = stablehlo.dot %11780, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %11782 = stablehlo.reshape %11781 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %11783 = stablehlo.add %11705, %11782 : tensor<3x1x4096xf32> + %11784 = stablehlo.broadcast_in_dim %11783, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %11785 = stablehlo.power %11784, %15 : tensor<3x1x4096xf32> + %11786 = stablehlo.reduce(%11785 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %11787 = stablehlo.reshape %11786 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %11788 = stablehlo.broadcast_in_dim %11787, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %11789 = stablehlo.divide %11788, %21 : tensor<3x1x1xf32> + %11790 = stablehlo.broadcast_in_dim %11789, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %11791 = stablehlo.add %11790, %25 : tensor<3x1x1xf32> + %11792 = stablehlo.rsqrt %11791 : tensor<3x1x1xf32> + %11793 = stablehlo.broadcast_in_dim %11792, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %11794 = stablehlo.multiply %11784, %11793 : tensor<3x1x4096xf32> + %11795 = stablehlo.broadcast_in_dim %11794, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %11796 = stablehlo.multiply %11795, %31 : tensor<3x1x4096xf32> + %11797 = stablehlo.reshape %11796 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %11798 = stablehlo.dot %11797, %117 : (tensor<3x4096xf32>, tensor<4096x8xf32>) -> tensor<3x8xf32> + %11799 = stablehlo.custom_call @byteir.softmax(%11798) {byteir_attrs = {axis = 1 : i64}} : (tensor<3x8xf32>) -> tensor<3x8xf32> + %11800:2 = stablehlo.custom_call @byteir.top_k(%11799) {byteir_attrs = {axis = [1], k = 2 : i64, sorted = true}} : (tensor<3x8xf32>) -> (tensor<3x2xf32>, tensor<3x2xi64>) + %11801 = stablehlo.reduce(%11800#0 init: %cst_20) applies stablehlo.add across dimensions = [1] : (tensor<3x2xf32>, tensor) -> tensor<3xf32> + %11802 = stablehlo.reshape %11801 : (tensor<3xf32>) -> tensor<3x1xf32> + %11803 = stablehlo.broadcast_in_dim %11800#0, dims = [0, 1] : (tensor<3x2xf32>) -> tensor<3x2xf32> + %11804 = stablehlo.broadcast_in_dim %11802, dims = [0, 1] : (tensor<3x1xf32>) -> tensor<3x2xf32> + %11805 = stablehlo.divide %11803, %11804 : tensor<3x2xf32> + %11806 = stablehlo.reshape %11800#1 : (tensor<3x2xi64>) -> tensor<3x2x1xi64> + %11807 = stablehlo.broadcast_in_dim %11806, dims = [0, 1, 2] : (tensor<3x2x1xi64>) -> tensor<3x2x8xi64> + %11808 = stablehlo.compare EQ, %11807, %137, SIGNED : (tensor<3x2x8xi64>, tensor<3x2x8xi64>) -> tensor<3x2x8xi1> + %11809 = stablehlo.convert %11808 : (tensor<3x2x8xi1>) -> tensor<3x2x8xi64> + %11810 = stablehlo.transpose %11809, dims = [2, 1, 0] : (tensor<3x2x8xi64>) -> tensor<8x2x3xi64> + %11811 = stablehlo.slice %11810 [0:1, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %11812 = stablehlo.reshape %11811 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %11813 = stablehlo.custom_call @byteir.non_zero(%11812) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_4128 = 
tensor.dim %11813, %c0 : tensor + %11814 = arith.index_cast %dim_4128 : index to i64 + %from_elements_4129 = tensor.from_elements %11814, %c1_i64 : tensor<2xi64> + %11815 = stablehlo.real_dynamic_slice %11813, %c_22, %from_elements_4129, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4130 = tensor.dim %11815, %c0 : tensor + %11816 = arith.index_cast %dim_4130 : index to i64 + %from_elements_4131 = tensor.from_elements %11816 : tensor<1xi64> + %11817 = stablehlo.dynamic_reshape %11815, %from_elements_4131 : (tensor, tensor<1xi64>) -> tensor + %from_elements_4132 = tensor.from_elements %11814, %c2_i64 : tensor<2xi64> + %11818 = stablehlo.real_dynamic_slice %11813, %c_24, %from_elements_4132, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4133 = tensor.dim %11818, %c0 : tensor + %11819 = arith.index_cast %dim_4133 : index to i64 + %from_elements_4134 = tensor.from_elements %11819 : tensor<1xi64> + %11820 = stablehlo.dynamic_reshape %11818, %from_elements_4134 : (tensor, tensor<1xi64>) -> tensor + %11821 = stablehlo.reshape %11797 : (tensor<3x4096xf32>) -> tensor<1x3x4096xf32> + %dim_4135 = tensor.dim %11820, %c0 : tensor + %11822 = arith.index_cast %dim_4135 : index to i64 + %from_elements_4136 = tensor.from_elements %11822, %c1_i64 : tensor<2xi64> + %11823 = stablehlo.dynamic_reshape %11820, %from_elements_4136 : (tensor, tensor<2xi64>) -> tensor + %dim_4137 = tensor.dim %11823, %c0 : tensor + %11824 = arith.index_cast %dim_4137 : index to i64 + %from_elements_4138 = tensor.from_elements %c1_i64, %11824, %c4096_i64 : tensor<3xi64> + %11825 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_4138, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4139 = tensor.dim %11825, %c1 : tensor<1x?x4096xi64> + %11826 = arith.index_cast %dim_4139 : index to i64 + %from_elements_4140 = tensor.from_elements %c1_i64, %11826, %c4096_i64, %c1_i64 : tensor<4xi64> + %11827 = stablehlo.dynamic_reshape %11825, %from_elements_4140 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %11828 = stablehlo.dynamic_broadcast_in_dim %11823, %from_elements_4138, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4141 = tensor.dim %11828, %c1 : tensor<1x?x4096xi64> + %11829 = arith.index_cast %dim_4141 : index to i64 + %from_elements_4142 = tensor.from_elements %c1_i64, %11829, %c4096_i64, %c1_i64 : tensor<4xi64> + %11830 = stablehlo.dynamic_reshape %11828, %from_elements_4142 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %11831 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_4138, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4143 = tensor.dim %11831, %c1 : tensor<1x?x4096xi64> + %11832 = arith.index_cast %dim_4143 : index to i64 + %from_elements_4144 = tensor.from_elements %c1_i64, %11832, %c4096_i64, %c1_i64 : tensor<4xi64> + %11833 = stablehlo.dynamic_reshape %11831, %from_elements_4144 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %11834 = stablehlo.concatenate %11827, %11830, %11833, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %11835 = "stablehlo.gather"(%11821, %11834) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %11836 = shape.shape_of %11835 : tensor<1x?x4096xf32> -> tensor<3xindex> + %11837 = 
shape.num_elements %11836 : tensor<3xindex> -> index + %11838 = stablehlo.compute_reshape_shape %11837, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %11839 = stablehlo.dynamic_reshape %11835, %11838 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %11840 = stablehlo.dot %11839, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %11841 = stablehlo.logistic %11840 : tensor + %11842 = shape.shape_of %11841 : tensor -> tensor<2xindex> + %11843 = shape.shape_of %11840 : tensor -> tensor<2xindex> + %11844 = shape.cstr_broadcastable %11842, %11843 : tensor<2xindex>, tensor<2xindex> + %11845 = shape.assuming %11844 -> (tensor) { + %19688 = shape.broadcast %11842, %11843 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %11841, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %11840, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %11846 = shape.shape_of %11845 : tensor -> tensor<2xindex> + %11847 = shape.cstr_broadcastable %11846, %11843 : tensor<2xindex>, tensor<2xindex> + %11848 = shape.assuming %11847 -> (tensor) { + %19688 = shape.broadcast %11846, %11843 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %11845, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %11840, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %11849 = stablehlo.dot %11848, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %11850 = stablehlo.reshape %11805 : (tensor<3x2xf32>) -> tensor<3x2x1xf32> + %dim_4145 = tensor.dim %11820, %c0 : tensor + %11851 = arith.index_cast %dim_4145 : index to i64 + %from_elements_4146 = tensor.from_elements %11851, %c1_i64 : tensor<2xi64> + %11852 = stablehlo.dynamic_reshape %11820, %from_elements_4146 : (tensor, tensor<2xi64>) -> tensor + %dim_4147 = tensor.dim %11817, %c0 : tensor + %11853 = arith.index_cast %dim_4147 : index to i64 + %from_elements_4148 = tensor.from_elements %11853, %c1_i64 : tensor<2xi64> + %11854 = stablehlo.dynamic_reshape %11817, %from_elements_4148 : (tensor, tensor<2xi64>) -> tensor + %11855 = stablehlo.concatenate %11852, %11854, dim = 1 : (tensor, tensor) -> tensor + %11856 = "stablehlo.gather"(%11850, %11855) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %11857 = shape.shape_of %11849 : tensor -> tensor<2xindex> + %11858 = shape.shape_of %11856 : tensor -> tensor<2xindex> + %11859 = shape.cstr_broadcastable %11857, %11858 : tensor<2xindex>, tensor<2xindex> + %11860 = shape.assuming %11859 -> (tensor) { + %19688 = shape.broadcast %11857, %11858 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %11849, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %11856, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %11861 = shape.shape_of %11860 : tensor -> tensor<2xindex> + %11862 = stablehlo.dynamic_broadcast_in_dim %11860, %11861, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %11863 = stablehlo.dynamic_broadcast_in_dim %213, 
%11861, dims = [] : (tensor, tensor<2xindex>) -> tensor + %11864 = stablehlo.multiply %11862, %11863 : tensor + %dim_4149 = tensor.dim %11823, %c0 : tensor + %11865 = arith.index_cast %dim_4149 : index to i64 + %dim_4150 = tensor.dim %11860, %c0 : tensor + %11866 = arith.index_cast %dim_4150 : index to i64 + %11867 = arith.maxsi %11865, %11866 : i64 + %11868 = arith.index_cast %11867 : i64 to index + %from_elements_4151 = tensor.from_elements %11868, %c4096 : tensor<2xindex> + %11869 = stablehlo.dynamic_broadcast_in_dim %11823, %from_elements_4151, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_4152 = tensor.dim %11869, %c0 : tensor + %11870 = arith.index_cast %dim_4152 : index to i64 + %from_elements_4153 = tensor.from_elements %11870, %c4096_i64 : tensor<2xi64> + %11871 = stablehlo.real_dynamic_slice %11864, %c_22, %from_elements_4153, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_4154 = tensor.from_elements %11870, %c4096_i64, %c1_i64 : tensor<3xi64> + %11872 = stablehlo.dynamic_reshape %11869, %from_elements_4154 : (tensor, tensor<3xi64>) -> tensor + %11873 = stablehlo.dynamic_iota %from_elements_4154, dim = 1 : (tensor<3xi64>) -> tensor + %11874 = stablehlo.concatenate %11872, %11873, dim = 2 : (tensor, tensor) -> tensor + %11875 = "stablehlo.scatter"(%cst_2, %11874, %11871) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %11876 = stablehlo.slice %11810 [1:2, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %11877 = stablehlo.reshape %11876 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %11878 = stablehlo.custom_call @byteir.non_zero(%11877) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_4155 = tensor.dim %11878, %c0 : tensor + %11879 = arith.index_cast %dim_4155 : index to i64 + %from_elements_4156 = tensor.from_elements %11879, %c1_i64 : tensor<2xi64> + %11880 = stablehlo.real_dynamic_slice %11878, %c_22, %from_elements_4156, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4157 = tensor.dim %11880, %c0 : tensor + %11881 = arith.index_cast %dim_4157 : index to i64 + %from_elements_4158 = tensor.from_elements %11881 : tensor<1xi64> + %11882 = stablehlo.dynamic_reshape %11880, %from_elements_4158 : (tensor, tensor<1xi64>) -> tensor + %from_elements_4159 = tensor.from_elements %11879, %c2_i64 : tensor<2xi64> + %11883 = stablehlo.real_dynamic_slice %11878, %c_24, %from_elements_4159, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4160 = tensor.dim %11883, %c0 : tensor + %11884 = arith.index_cast %dim_4160 : index to i64 + %from_elements_4161 = tensor.from_elements %11884 : tensor<1xi64> + %11885 = stablehlo.dynamic_reshape %11883, %from_elements_4161 : (tensor, tensor<1xi64>) -> tensor + %dim_4162 = tensor.dim %11885, %c0 : tensor + %11886 = arith.index_cast %dim_4162 : index to i64 + %from_elements_4163 = tensor.from_elements %11886, %c1_i64 : tensor<2xi64> + %11887 = stablehlo.dynamic_reshape %11885, %from_elements_4163 : (tensor, tensor<2xi64>) -> tensor + %dim_4164 = tensor.dim %11887, %c0 : tensor + %11888 = arith.index_cast %dim_4164 : index to i64 + %from_elements_4165 = tensor.from_elements %c1_i64, %11888, %c4096_i64 : tensor<3xi64> + %11889 = stablehlo.dynamic_broadcast_in_dim %173, 
%from_elements_4165, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4166 = tensor.dim %11889, %c1 : tensor<1x?x4096xi64> + %11890 = arith.index_cast %dim_4166 : index to i64 + %from_elements_4167 = tensor.from_elements %c1_i64, %11890, %c4096_i64, %c1_i64 : tensor<4xi64> + %11891 = stablehlo.dynamic_reshape %11889, %from_elements_4167 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %11892 = stablehlo.dynamic_broadcast_in_dim %11887, %from_elements_4165, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4168 = tensor.dim %11892, %c1 : tensor<1x?x4096xi64> + %11893 = arith.index_cast %dim_4168 : index to i64 + %from_elements_4169 = tensor.from_elements %c1_i64, %11893, %c4096_i64, %c1_i64 : tensor<4xi64> + %11894 = stablehlo.dynamic_reshape %11892, %from_elements_4169 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %11895 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_4165, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4170 = tensor.dim %11895, %c1 : tensor<1x?x4096xi64> + %11896 = arith.index_cast %dim_4170 : index to i64 + %from_elements_4171 = tensor.from_elements %c1_i64, %11896, %c4096_i64, %c1_i64 : tensor<4xi64> + %11897 = stablehlo.dynamic_reshape %11895, %from_elements_4171 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %11898 = stablehlo.concatenate %11891, %11894, %11897, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %11899 = "stablehlo.gather"(%11821, %11898) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %11900 = shape.shape_of %11899 : tensor<1x?x4096xf32> -> tensor<3xindex> + %11901 = shape.num_elements %11900 : tensor<3xindex> -> index + %11902 = stablehlo.compute_reshape_shape %11901, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %11903 = stablehlo.dynamic_reshape %11899, %11902 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %11904 = stablehlo.dot %11903, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %11905 = stablehlo.logistic %11904 : tensor + %11906 = shape.shape_of %11905 : tensor -> tensor<2xindex> + %11907 = shape.shape_of %11904 : tensor -> tensor<2xindex> + %11908 = shape.cstr_broadcastable %11906, %11907 : tensor<2xindex>, tensor<2xindex> + %11909 = shape.assuming %11908 -> (tensor) { + %19688 = shape.broadcast %11906, %11907 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %11905, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %11904, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %11910 = shape.shape_of %11909 : tensor -> tensor<2xindex> + %11911 = shape.cstr_broadcastable %11910, %11907 : tensor<2xindex>, tensor<2xindex> + %11912 = shape.assuming %11911 -> (tensor) { + %19688 = shape.broadcast %11910, %11907 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %11909, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %11904, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %11913 = 
stablehlo.dot %11912, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_4172 = tensor.dim %11885, %c0 : tensor + %11914 = arith.index_cast %dim_4172 : index to i64 + %from_elements_4173 = tensor.from_elements %11914, %c1_i64 : tensor<2xi64> + %11915 = stablehlo.dynamic_reshape %11885, %from_elements_4173 : (tensor, tensor<2xi64>) -> tensor + %dim_4174 = tensor.dim %11882, %c0 : tensor + %11916 = arith.index_cast %dim_4174 : index to i64 + %from_elements_4175 = tensor.from_elements %11916, %c1_i64 : tensor<2xi64> + %11917 = stablehlo.dynamic_reshape %11882, %from_elements_4175 : (tensor, tensor<2xi64>) -> tensor + %11918 = stablehlo.concatenate %11915, %11917, dim = 1 : (tensor, tensor) -> tensor + %11919 = "stablehlo.gather"(%11850, %11918) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %11920 = shape.shape_of %11913 : tensor -> tensor<2xindex> + %11921 = shape.shape_of %11919 : tensor -> tensor<2xindex> + %11922 = shape.cstr_broadcastable %11920, %11921 : tensor<2xindex>, tensor<2xindex> + %11923 = shape.assuming %11922 -> (tensor) { + %19688 = shape.broadcast %11920, %11921 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %11913, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %11919, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %11924 = shape.shape_of %11923 : tensor -> tensor<2xindex> + %11925 = stablehlo.dynamic_broadcast_in_dim %11923, %11924, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %11926 = stablehlo.dynamic_broadcast_in_dim %213, %11924, dims = [] : (tensor, tensor<2xindex>) -> tensor + %11927 = stablehlo.multiply %11925, %11926 : tensor + %dim_4176 = tensor.dim %11887, %c0 : tensor + %11928 = arith.index_cast %dim_4176 : index to i64 + %dim_4177 = tensor.dim %11923, %c0 : tensor + %11929 = arith.index_cast %dim_4177 : index to i64 + %11930 = arith.maxsi %11928, %11929 : i64 + %11931 = arith.index_cast %11930 : i64 to index + %from_elements_4178 = tensor.from_elements %11931, %c4096 : tensor<2xindex> + %11932 = stablehlo.dynamic_broadcast_in_dim %11887, %from_elements_4178, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_4179 = tensor.dim %11932, %c0 : tensor + %11933 = arith.index_cast %dim_4179 : index to i64 + %from_elements_4180 = tensor.from_elements %11933, %c4096_i64 : tensor<2xi64> + %11934 = stablehlo.real_dynamic_slice %11927, %c_22, %from_elements_4180, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_4181 = tensor.from_elements %11933, %c4096_i64, %c1_i64 : tensor<3xi64> + %11935 = stablehlo.dynamic_reshape %11932, %from_elements_4181 : (tensor, tensor<3xi64>) -> tensor + %11936 = stablehlo.dynamic_iota %from_elements_4181, dim = 1 : (tensor<3xi64>) -> tensor + %11937 = stablehlo.concatenate %11935, %11936, dim = 2 : (tensor, tensor) -> tensor + %11938 = "stablehlo.scatter"(%11875, %11937, %11934) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %11939 = stablehlo.slice %11810 [2:3, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %11940 = 
stablehlo.reshape %11939 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %11941 = stablehlo.custom_call @byteir.non_zero(%11940) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_4182 = tensor.dim %11941, %c0 : tensor + %11942 = arith.index_cast %dim_4182 : index to i64 + %from_elements_4183 = tensor.from_elements %11942, %c1_i64 : tensor<2xi64> + %11943 = stablehlo.real_dynamic_slice %11941, %c_22, %from_elements_4183, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4184 = tensor.dim %11943, %c0 : tensor + %11944 = arith.index_cast %dim_4184 : index to i64 + %from_elements_4185 = tensor.from_elements %11944 : tensor<1xi64> + %11945 = stablehlo.dynamic_reshape %11943, %from_elements_4185 : (tensor, tensor<1xi64>) -> tensor + %from_elements_4186 = tensor.from_elements %11942, %c2_i64 : tensor<2xi64> + %11946 = stablehlo.real_dynamic_slice %11941, %c_24, %from_elements_4186, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4187 = tensor.dim %11946, %c0 : tensor + %11947 = arith.index_cast %dim_4187 : index to i64 + %from_elements_4188 = tensor.from_elements %11947 : tensor<1xi64> + %11948 = stablehlo.dynamic_reshape %11946, %from_elements_4188 : (tensor, tensor<1xi64>) -> tensor + %dim_4189 = tensor.dim %11948, %c0 : tensor + %11949 = arith.index_cast %dim_4189 : index to i64 + %from_elements_4190 = tensor.from_elements %11949, %c1_i64 : tensor<2xi64> + %11950 = stablehlo.dynamic_reshape %11948, %from_elements_4190 : (tensor, tensor<2xi64>) -> tensor + %dim_4191 = tensor.dim %11950, %c0 : tensor + %11951 = arith.index_cast %dim_4191 : index to i64 + %from_elements_4192 = tensor.from_elements %c1_i64, %11951, %c4096_i64 : tensor<3xi64> + %11952 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_4192, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4193 = tensor.dim %11952, %c1 : tensor<1x?x4096xi64> + %11953 = arith.index_cast %dim_4193 : index to i64 + %from_elements_4194 = tensor.from_elements %c1_i64, %11953, %c4096_i64, %c1_i64 : tensor<4xi64> + %11954 = stablehlo.dynamic_reshape %11952, %from_elements_4194 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %11955 = stablehlo.dynamic_broadcast_in_dim %11950, %from_elements_4192, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4195 = tensor.dim %11955, %c1 : tensor<1x?x4096xi64> + %11956 = arith.index_cast %dim_4195 : index to i64 + %from_elements_4196 = tensor.from_elements %c1_i64, %11956, %c4096_i64, %c1_i64 : tensor<4xi64> + %11957 = stablehlo.dynamic_reshape %11955, %from_elements_4196 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %11958 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_4192, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4197 = tensor.dim %11958, %c1 : tensor<1x?x4096xi64> + %11959 = arith.index_cast %dim_4197 : index to i64 + %from_elements_4198 = tensor.from_elements %c1_i64, %11959, %c4096_i64, %c1_i64 : tensor<4xi64> + %11960 = stablehlo.dynamic_reshape %11958, %from_elements_4198 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %11961 = stablehlo.concatenate %11954, %11957, %11960, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %11962 = "stablehlo.gather"(%11821, %11961) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> 
tensor<1x?x4096xf32> + %11963 = shape.shape_of %11962 : tensor<1x?x4096xf32> -> tensor<3xindex> + %11964 = shape.num_elements %11963 : tensor<3xindex> -> index + %11965 = stablehlo.compute_reshape_shape %11964, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %11966 = stablehlo.dynamic_reshape %11962, %11965 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %11967 = stablehlo.dot %11966, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %11968 = stablehlo.logistic %11967 : tensor + %11969 = shape.shape_of %11968 : tensor -> tensor<2xindex> + %11970 = shape.shape_of %11967 : tensor -> tensor<2xindex> + %11971 = shape.cstr_broadcastable %11969, %11970 : tensor<2xindex>, tensor<2xindex> + %11972 = shape.assuming %11971 -> (tensor) { + %19688 = shape.broadcast %11969, %11970 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %11968, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %11967, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %11973 = shape.shape_of %11972 : tensor -> tensor<2xindex> + %11974 = shape.cstr_broadcastable %11973, %11970 : tensor<2xindex>, tensor<2xindex> + %11975 = shape.assuming %11974 -> (tensor) { + %19688 = shape.broadcast %11973, %11970 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %11972, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %11967, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %11976 = stablehlo.dot %11975, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_4199 = tensor.dim %11948, %c0 : tensor + %11977 = arith.index_cast %dim_4199 : index to i64 + %from_elements_4200 = tensor.from_elements %11977, %c1_i64 : tensor<2xi64> + %11978 = stablehlo.dynamic_reshape %11948, %from_elements_4200 : (tensor, tensor<2xi64>) -> tensor + %dim_4201 = tensor.dim %11945, %c0 : tensor + %11979 = arith.index_cast %dim_4201 : index to i64 + %from_elements_4202 = tensor.from_elements %11979, %c1_i64 : tensor<2xi64> + %11980 = stablehlo.dynamic_reshape %11945, %from_elements_4202 : (tensor, tensor<2xi64>) -> tensor + %11981 = stablehlo.concatenate %11978, %11980, dim = 1 : (tensor, tensor) -> tensor + %11982 = "stablehlo.gather"(%11850, %11981) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %11983 = shape.shape_of %11976 : tensor -> tensor<2xindex> + %11984 = shape.shape_of %11982 : tensor -> tensor<2xindex> + %11985 = shape.cstr_broadcastable %11983, %11984 : tensor<2xindex>, tensor<2xindex> + %11986 = shape.assuming %11985 -> (tensor) { + %19688 = shape.broadcast %11983, %11984 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %11976, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %11982, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %11987 = shape.shape_of %11986 : tensor -> tensor<2xindex> + %11988 = stablehlo.dynamic_broadcast_in_dim %11986, %11987, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %11989 = 
stablehlo.dynamic_broadcast_in_dim %213, %11987, dims = [] : (tensor, tensor<2xindex>) -> tensor + %11990 = stablehlo.multiply %11988, %11989 : tensor + %dim_4203 = tensor.dim %11950, %c0 : tensor + %11991 = arith.index_cast %dim_4203 : index to i64 + %dim_4204 = tensor.dim %11986, %c0 : tensor + %11992 = arith.index_cast %dim_4204 : index to i64 + %11993 = arith.maxsi %11991, %11992 : i64 + %11994 = arith.index_cast %11993 : i64 to index + %from_elements_4205 = tensor.from_elements %11994, %c4096 : tensor<2xindex> + %11995 = stablehlo.dynamic_broadcast_in_dim %11950, %from_elements_4205, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_4206 = tensor.dim %11995, %c0 : tensor + %11996 = arith.index_cast %dim_4206 : index to i64 + %from_elements_4207 = tensor.from_elements %11996, %c4096_i64 : tensor<2xi64> + %11997 = stablehlo.real_dynamic_slice %11990, %c_22, %from_elements_4207, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_4208 = tensor.from_elements %11996, %c4096_i64, %c1_i64 : tensor<3xi64> + %11998 = stablehlo.dynamic_reshape %11995, %from_elements_4208 : (tensor, tensor<3xi64>) -> tensor + %11999 = stablehlo.dynamic_iota %from_elements_4208, dim = 1 : (tensor<3xi64>) -> tensor + %12000 = stablehlo.concatenate %11998, %11999, dim = 2 : (tensor, tensor) -> tensor + %12001 = "stablehlo.scatter"(%11938, %12000, %11997) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %12002 = stablehlo.slice %11810 [3:4, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %12003 = stablehlo.reshape %12002 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %12004 = stablehlo.custom_call @byteir.non_zero(%12003) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_4209 = tensor.dim %12004, %c0 : tensor + %12005 = arith.index_cast %dim_4209 : index to i64 + %from_elements_4210 = tensor.from_elements %12005, %c1_i64 : tensor<2xi64> + %12006 = stablehlo.real_dynamic_slice %12004, %c_22, %from_elements_4210, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4211 = tensor.dim %12006, %c0 : tensor + %12007 = arith.index_cast %dim_4211 : index to i64 + %from_elements_4212 = tensor.from_elements %12007 : tensor<1xi64> + %12008 = stablehlo.dynamic_reshape %12006, %from_elements_4212 : (tensor, tensor<1xi64>) -> tensor + %from_elements_4213 = tensor.from_elements %12005, %c2_i64 : tensor<2xi64> + %12009 = stablehlo.real_dynamic_slice %12004, %c_24, %from_elements_4213, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4214 = tensor.dim %12009, %c0 : tensor + %12010 = arith.index_cast %dim_4214 : index to i64 + %from_elements_4215 = tensor.from_elements %12010 : tensor<1xi64> + %12011 = stablehlo.dynamic_reshape %12009, %from_elements_4215 : (tensor, tensor<1xi64>) -> tensor + %dim_4216 = tensor.dim %12011, %c0 : tensor + %12012 = arith.index_cast %dim_4216 : index to i64 + %from_elements_4217 = tensor.from_elements %12012, %c1_i64 : tensor<2xi64> + %12013 = stablehlo.dynamic_reshape %12011, %from_elements_4217 : (tensor, tensor<2xi64>) -> tensor + %dim_4218 = tensor.dim %12013, %c0 : tensor + %12014 = arith.index_cast %dim_4218 : index to i64 + %from_elements_4219 = tensor.from_elements %c1_i64, %12014, %c4096_i64 : tensor<3xi64> + %12015 = 
stablehlo.dynamic_broadcast_in_dim %173, %from_elements_4219, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4220 = tensor.dim %12015, %c1 : tensor<1x?x4096xi64> + %12016 = arith.index_cast %dim_4220 : index to i64 + %from_elements_4221 = tensor.from_elements %c1_i64, %12016, %c4096_i64, %c1_i64 : tensor<4xi64> + %12017 = stablehlo.dynamic_reshape %12015, %from_elements_4221 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %12018 = stablehlo.dynamic_broadcast_in_dim %12013, %from_elements_4219, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4222 = tensor.dim %12018, %c1 : tensor<1x?x4096xi64> + %12019 = arith.index_cast %dim_4222 : index to i64 + %from_elements_4223 = tensor.from_elements %c1_i64, %12019, %c4096_i64, %c1_i64 : tensor<4xi64> + %12020 = stablehlo.dynamic_reshape %12018, %from_elements_4223 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %12021 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_4219, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4224 = tensor.dim %12021, %c1 : tensor<1x?x4096xi64> + %12022 = arith.index_cast %dim_4224 : index to i64 + %from_elements_4225 = tensor.from_elements %c1_i64, %12022, %c4096_i64, %c1_i64 : tensor<4xi64> + %12023 = stablehlo.dynamic_reshape %12021, %from_elements_4225 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %12024 = stablehlo.concatenate %12017, %12020, %12023, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %12025 = "stablehlo.gather"(%11821, %12024) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %12026 = shape.shape_of %12025 : tensor<1x?x4096xf32> -> tensor<3xindex> + %12027 = shape.num_elements %12026 : tensor<3xindex> -> index + %12028 = stablehlo.compute_reshape_shape %12027, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %12029 = stablehlo.dynamic_reshape %12025, %12028 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %12030 = stablehlo.dot %12029, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %12031 = stablehlo.logistic %12030 : tensor + %12032 = shape.shape_of %12031 : tensor -> tensor<2xindex> + %12033 = shape.shape_of %12030 : tensor -> tensor<2xindex> + %12034 = shape.cstr_broadcastable %12032, %12033 : tensor<2xindex>, tensor<2xindex> + %12035 = shape.assuming %12034 -> (tensor) { + %19688 = shape.broadcast %12032, %12033 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %12031, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %12030, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %12036 = shape.shape_of %12035 : tensor -> tensor<2xindex> + %12037 = shape.cstr_broadcastable %12036, %12033 : tensor<2xindex>, tensor<2xindex> + %12038 = shape.assuming %12037 -> (tensor) { + %19688 = shape.broadcast %12036, %12033 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %12035, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %12030, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + 
shape.assuming_yield %19691 : tensor + } + %12039 = stablehlo.dot %12038, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_4226 = tensor.dim %12011, %c0 : tensor + %12040 = arith.index_cast %dim_4226 : index to i64 + %from_elements_4227 = tensor.from_elements %12040, %c1_i64 : tensor<2xi64> + %12041 = stablehlo.dynamic_reshape %12011, %from_elements_4227 : (tensor, tensor<2xi64>) -> tensor + %dim_4228 = tensor.dim %12008, %c0 : tensor + %12042 = arith.index_cast %dim_4228 : index to i64 + %from_elements_4229 = tensor.from_elements %12042, %c1_i64 : tensor<2xi64> + %12043 = stablehlo.dynamic_reshape %12008, %from_elements_4229 : (tensor, tensor<2xi64>) -> tensor + %12044 = stablehlo.concatenate %12041, %12043, dim = 1 : (tensor, tensor) -> tensor + %12045 = "stablehlo.gather"(%11850, %12044) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %12046 = shape.shape_of %12039 : tensor -> tensor<2xindex> + %12047 = shape.shape_of %12045 : tensor -> tensor<2xindex> + %12048 = shape.cstr_broadcastable %12046, %12047 : tensor<2xindex>, tensor<2xindex> + %12049 = shape.assuming %12048 -> (tensor) { + %19688 = shape.broadcast %12046, %12047 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %12039, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %12045, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %12050 = shape.shape_of %12049 : tensor -> tensor<2xindex> + %12051 = stablehlo.dynamic_broadcast_in_dim %12049, %12050, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %12052 = stablehlo.dynamic_broadcast_in_dim %213, %12050, dims = [] : (tensor, tensor<2xindex>) -> tensor + %12053 = stablehlo.multiply %12051, %12052 : tensor + %dim_4230 = tensor.dim %12013, %c0 : tensor + %12054 = arith.index_cast %dim_4230 : index to i64 + %dim_4231 = tensor.dim %12049, %c0 : tensor + %12055 = arith.index_cast %dim_4231 : index to i64 + %12056 = arith.maxsi %12054, %12055 : i64 + %12057 = arith.index_cast %12056 : i64 to index + %from_elements_4232 = tensor.from_elements %12057, %c4096 : tensor<2xindex> + %12058 = stablehlo.dynamic_broadcast_in_dim %12013, %from_elements_4232, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_4233 = tensor.dim %12058, %c0 : tensor + %12059 = arith.index_cast %dim_4233 : index to i64 + %from_elements_4234 = tensor.from_elements %12059, %c4096_i64 : tensor<2xi64> + %12060 = stablehlo.real_dynamic_slice %12053, %c_22, %from_elements_4234, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_4235 = tensor.from_elements %12059, %c4096_i64, %c1_i64 : tensor<3xi64> + %12061 = stablehlo.dynamic_reshape %12058, %from_elements_4235 : (tensor, tensor<3xi64>) -> tensor + %12062 = stablehlo.dynamic_iota %from_elements_4235, dim = 1 : (tensor<3xi64>) -> tensor + %12063 = stablehlo.concatenate %12061, %12062, dim = 2 : (tensor, tensor) -> tensor + %12064 = "stablehlo.scatter"(%12001, %12063, %12060) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %12065 = stablehlo.slice %11810 [4:5, 0:2, 0:3] : 
(tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %12066 = stablehlo.reshape %12065 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %12067 = stablehlo.custom_call @byteir.non_zero(%12066) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_4236 = tensor.dim %12067, %c0 : tensor + %12068 = arith.index_cast %dim_4236 : index to i64 + %from_elements_4237 = tensor.from_elements %12068, %c1_i64 : tensor<2xi64> + %12069 = stablehlo.real_dynamic_slice %12067, %c_22, %from_elements_4237, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4238 = tensor.dim %12069, %c0 : tensor + %12070 = arith.index_cast %dim_4238 : index to i64 + %from_elements_4239 = tensor.from_elements %12070 : tensor<1xi64> + %12071 = stablehlo.dynamic_reshape %12069, %from_elements_4239 : (tensor, tensor<1xi64>) -> tensor + %from_elements_4240 = tensor.from_elements %12068, %c2_i64 : tensor<2xi64> + %12072 = stablehlo.real_dynamic_slice %12067, %c_24, %from_elements_4240, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4241 = tensor.dim %12072, %c0 : tensor + %12073 = arith.index_cast %dim_4241 : index to i64 + %from_elements_4242 = tensor.from_elements %12073 : tensor<1xi64> + %12074 = stablehlo.dynamic_reshape %12072, %from_elements_4242 : (tensor, tensor<1xi64>) -> tensor + %dim_4243 = tensor.dim %12074, %c0 : tensor + %12075 = arith.index_cast %dim_4243 : index to i64 + %from_elements_4244 = tensor.from_elements %12075, %c1_i64 : tensor<2xi64> + %12076 = stablehlo.dynamic_reshape %12074, %from_elements_4244 : (tensor, tensor<2xi64>) -> tensor + %dim_4245 = tensor.dim %12076, %c0 : tensor + %12077 = arith.index_cast %dim_4245 : index to i64 + %from_elements_4246 = tensor.from_elements %c1_i64, %12077, %c4096_i64 : tensor<3xi64> + %12078 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_4246, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4247 = tensor.dim %12078, %c1 : tensor<1x?x4096xi64> + %12079 = arith.index_cast %dim_4247 : index to i64 + %from_elements_4248 = tensor.from_elements %c1_i64, %12079, %c4096_i64, %c1_i64 : tensor<4xi64> + %12080 = stablehlo.dynamic_reshape %12078, %from_elements_4248 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %12081 = stablehlo.dynamic_broadcast_in_dim %12076, %from_elements_4246, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4249 = tensor.dim %12081, %c1 : tensor<1x?x4096xi64> + %12082 = arith.index_cast %dim_4249 : index to i64 + %from_elements_4250 = tensor.from_elements %c1_i64, %12082, %c4096_i64, %c1_i64 : tensor<4xi64> + %12083 = stablehlo.dynamic_reshape %12081, %from_elements_4250 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %12084 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_4246, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4251 = tensor.dim %12084, %c1 : tensor<1x?x4096xi64> + %12085 = arith.index_cast %dim_4251 : index to i64 + %from_elements_4252 = tensor.from_elements %c1_i64, %12085, %c4096_i64, %c1_i64 : tensor<4xi64> + %12086 = stablehlo.dynamic_reshape %12084, %from_elements_4252 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %12087 = stablehlo.concatenate %12080, %12083, %12086, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %12088 = "stablehlo.gather"(%11821, %12087) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : 
(tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %12089 = shape.shape_of %12088 : tensor<1x?x4096xf32> -> tensor<3xindex> + %12090 = shape.num_elements %12089 : tensor<3xindex> -> index + %12091 = stablehlo.compute_reshape_shape %12090, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %12092 = stablehlo.dynamic_reshape %12088, %12091 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %12093 = stablehlo.dot %12092, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %12094 = stablehlo.logistic %12093 : tensor + %12095 = shape.shape_of %12094 : tensor -> tensor<2xindex> + %12096 = shape.shape_of %12093 : tensor -> tensor<2xindex> + %12097 = shape.cstr_broadcastable %12095, %12096 : tensor<2xindex>, tensor<2xindex> + %12098 = shape.assuming %12097 -> (tensor) { + %19688 = shape.broadcast %12095, %12096 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %12094, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %12093, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %12099 = shape.shape_of %12098 : tensor -> tensor<2xindex> + %12100 = shape.cstr_broadcastable %12099, %12096 : tensor<2xindex>, tensor<2xindex> + %12101 = shape.assuming %12100 -> (tensor) { + %19688 = shape.broadcast %12099, %12096 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %12098, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %12093, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %12102 = stablehlo.dot %12101, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_4253 = tensor.dim %12074, %c0 : tensor + %12103 = arith.index_cast %dim_4253 : index to i64 + %from_elements_4254 = tensor.from_elements %12103, %c1_i64 : tensor<2xi64> + %12104 = stablehlo.dynamic_reshape %12074, %from_elements_4254 : (tensor, tensor<2xi64>) -> tensor + %dim_4255 = tensor.dim %12071, %c0 : tensor + %12105 = arith.index_cast %dim_4255 : index to i64 + %from_elements_4256 = tensor.from_elements %12105, %c1_i64 : tensor<2xi64> + %12106 = stablehlo.dynamic_reshape %12071, %from_elements_4256 : (tensor, tensor<2xi64>) -> tensor + %12107 = stablehlo.concatenate %12104, %12106, dim = 1 : (tensor, tensor) -> tensor + %12108 = "stablehlo.gather"(%11850, %12107) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %12109 = shape.shape_of %12102 : tensor -> tensor<2xindex> + %12110 = shape.shape_of %12108 : tensor -> tensor<2xindex> + %12111 = shape.cstr_broadcastable %12109, %12110 : tensor<2xindex>, tensor<2xindex> + %12112 = shape.assuming %12111 -> (tensor) { + %19688 = shape.broadcast %12109, %12110 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %12102, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %12108, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %12113 = shape.shape_of %12112 : tensor -> tensor<2xindex> + %12114 = stablehlo.dynamic_broadcast_in_dim %12112, %12113, dims = [0, 1] : (tensor, 
tensor<2xindex>) -> tensor + %12115 = stablehlo.dynamic_broadcast_in_dim %213, %12113, dims = [] : (tensor, tensor<2xindex>) -> tensor + %12116 = stablehlo.multiply %12114, %12115 : tensor + %dim_4257 = tensor.dim %12076, %c0 : tensor + %12117 = arith.index_cast %dim_4257 : index to i64 + %dim_4258 = tensor.dim %12112, %c0 : tensor + %12118 = arith.index_cast %dim_4258 : index to i64 + %12119 = arith.maxsi %12117, %12118 : i64 + %12120 = arith.index_cast %12119 : i64 to index + %from_elements_4259 = tensor.from_elements %12120, %c4096 : tensor<2xindex> + %12121 = stablehlo.dynamic_broadcast_in_dim %12076, %from_elements_4259, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_4260 = tensor.dim %12121, %c0 : tensor + %12122 = arith.index_cast %dim_4260 : index to i64 + %from_elements_4261 = tensor.from_elements %12122, %c4096_i64 : tensor<2xi64> + %12123 = stablehlo.real_dynamic_slice %12116, %c_22, %from_elements_4261, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_4262 = tensor.from_elements %12122, %c4096_i64, %c1_i64 : tensor<3xi64> + %12124 = stablehlo.dynamic_reshape %12121, %from_elements_4262 : (tensor, tensor<3xi64>) -> tensor + %12125 = stablehlo.dynamic_iota %from_elements_4262, dim = 1 : (tensor<3xi64>) -> tensor + %12126 = stablehlo.concatenate %12124, %12125, dim = 2 : (tensor, tensor) -> tensor + %12127 = "stablehlo.scatter"(%12064, %12126, %12123) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %12128 = stablehlo.slice %11810 [5:6, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %12129 = stablehlo.reshape %12128 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %12130 = stablehlo.custom_call @byteir.non_zero(%12129) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_4263 = tensor.dim %12130, %c0 : tensor + %12131 = arith.index_cast %dim_4263 : index to i64 + %from_elements_4264 = tensor.from_elements %12131, %c1_i64 : tensor<2xi64> + %12132 = stablehlo.real_dynamic_slice %12130, %c_22, %from_elements_4264, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4265 = tensor.dim %12132, %c0 : tensor + %12133 = arith.index_cast %dim_4265 : index to i64 + %from_elements_4266 = tensor.from_elements %12133 : tensor<1xi64> + %12134 = stablehlo.dynamic_reshape %12132, %from_elements_4266 : (tensor, tensor<1xi64>) -> tensor + %from_elements_4267 = tensor.from_elements %12131, %c2_i64 : tensor<2xi64> + %12135 = stablehlo.real_dynamic_slice %12130, %c_24, %from_elements_4267, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4268 = tensor.dim %12135, %c0 : tensor + %12136 = arith.index_cast %dim_4268 : index to i64 + %from_elements_4269 = tensor.from_elements %12136 : tensor<1xi64> + %12137 = stablehlo.dynamic_reshape %12135, %from_elements_4269 : (tensor, tensor<1xi64>) -> tensor + %dim_4270 = tensor.dim %12137, %c0 : tensor + %12138 = arith.index_cast %dim_4270 : index to i64 + %from_elements_4271 = tensor.from_elements %12138, %c1_i64 : tensor<2xi64> + %12139 = stablehlo.dynamic_reshape %12137, %from_elements_4271 : (tensor, tensor<2xi64>) -> tensor + %dim_4272 = tensor.dim %12139, %c0 : tensor + %12140 = arith.index_cast %dim_4272 : index to i64 + %from_elements_4273 = tensor.from_elements %c1_i64, %12140, %c4096_i64 : 
tensor<3xi64> + %12141 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_4273, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4274 = tensor.dim %12141, %c1 : tensor<1x?x4096xi64> + %12142 = arith.index_cast %dim_4274 : index to i64 + %from_elements_4275 = tensor.from_elements %c1_i64, %12142, %c4096_i64, %c1_i64 : tensor<4xi64> + %12143 = stablehlo.dynamic_reshape %12141, %from_elements_4275 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %12144 = stablehlo.dynamic_broadcast_in_dim %12139, %from_elements_4273, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4276 = tensor.dim %12144, %c1 : tensor<1x?x4096xi64> + %12145 = arith.index_cast %dim_4276 : index to i64 + %from_elements_4277 = tensor.from_elements %c1_i64, %12145, %c4096_i64, %c1_i64 : tensor<4xi64> + %12146 = stablehlo.dynamic_reshape %12144, %from_elements_4277 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %12147 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_4273, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4278 = tensor.dim %12147, %c1 : tensor<1x?x4096xi64> + %12148 = arith.index_cast %dim_4278 : index to i64 + %from_elements_4279 = tensor.from_elements %c1_i64, %12148, %c4096_i64, %c1_i64 : tensor<4xi64> + %12149 = stablehlo.dynamic_reshape %12147, %from_elements_4279 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %12150 = stablehlo.concatenate %12143, %12146, %12149, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %12151 = "stablehlo.gather"(%11821, %12150) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %12152 = shape.shape_of %12151 : tensor<1x?x4096xf32> -> tensor<3xindex> + %12153 = shape.num_elements %12152 : tensor<3xindex> -> index + %12154 = stablehlo.compute_reshape_shape %12153, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %12155 = stablehlo.dynamic_reshape %12151, %12154 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %12156 = stablehlo.dot %12155, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %12157 = stablehlo.logistic %12156 : tensor + %12158 = shape.shape_of %12157 : tensor -> tensor<2xindex> + %12159 = shape.shape_of %12156 : tensor -> tensor<2xindex> + %12160 = shape.cstr_broadcastable %12158, %12159 : tensor<2xindex>, tensor<2xindex> + %12161 = shape.assuming %12160 -> (tensor) { + %19688 = shape.broadcast %12158, %12159 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %12157, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %12156, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %12162 = shape.shape_of %12161 : tensor -> tensor<2xindex> + %12163 = shape.cstr_broadcastable %12162, %12159 : tensor<2xindex>, tensor<2xindex> + %12164 = shape.assuming %12163 -> (tensor) { + %19688 = shape.broadcast %12162, %12159 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %12161, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %12156, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, 
%19690 : tensor + shape.assuming_yield %19691 : tensor + } + %12165 = stablehlo.dot %12164, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_4280 = tensor.dim %12137, %c0 : tensor + %12166 = arith.index_cast %dim_4280 : index to i64 + %from_elements_4281 = tensor.from_elements %12166, %c1_i64 : tensor<2xi64> + %12167 = stablehlo.dynamic_reshape %12137, %from_elements_4281 : (tensor, tensor<2xi64>) -> tensor + %dim_4282 = tensor.dim %12134, %c0 : tensor + %12168 = arith.index_cast %dim_4282 : index to i64 + %from_elements_4283 = tensor.from_elements %12168, %c1_i64 : tensor<2xi64> + %12169 = stablehlo.dynamic_reshape %12134, %from_elements_4283 : (tensor, tensor<2xi64>) -> tensor + %12170 = stablehlo.concatenate %12167, %12169, dim = 1 : (tensor, tensor) -> tensor + %12171 = "stablehlo.gather"(%11850, %12170) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %12172 = shape.shape_of %12165 : tensor -> tensor<2xindex> + %12173 = shape.shape_of %12171 : tensor -> tensor<2xindex> + %12174 = shape.cstr_broadcastable %12172, %12173 : tensor<2xindex>, tensor<2xindex> + %12175 = shape.assuming %12174 -> (tensor) { + %19688 = shape.broadcast %12172, %12173 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %12165, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %12171, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %12176 = shape.shape_of %12175 : tensor -> tensor<2xindex> + %12177 = stablehlo.dynamic_broadcast_in_dim %12175, %12176, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %12178 = stablehlo.dynamic_broadcast_in_dim %213, %12176, dims = [] : (tensor, tensor<2xindex>) -> tensor + %12179 = stablehlo.multiply %12177, %12178 : tensor + %dim_4284 = tensor.dim %12139, %c0 : tensor + %12180 = arith.index_cast %dim_4284 : index to i64 + %dim_4285 = tensor.dim %12175, %c0 : tensor + %12181 = arith.index_cast %dim_4285 : index to i64 + %12182 = arith.maxsi %12180, %12181 : i64 + %12183 = arith.index_cast %12182 : i64 to index + %from_elements_4286 = tensor.from_elements %12183, %c4096 : tensor<2xindex> + %12184 = stablehlo.dynamic_broadcast_in_dim %12139, %from_elements_4286, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_4287 = tensor.dim %12184, %c0 : tensor + %12185 = arith.index_cast %dim_4287 : index to i64 + %from_elements_4288 = tensor.from_elements %12185, %c4096_i64 : tensor<2xi64> + %12186 = stablehlo.real_dynamic_slice %12179, %c_22, %from_elements_4288, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_4289 = tensor.from_elements %12185, %c4096_i64, %c1_i64 : tensor<3xi64> + %12187 = stablehlo.dynamic_reshape %12184, %from_elements_4289 : (tensor, tensor<3xi64>) -> tensor + %12188 = stablehlo.dynamic_iota %from_elements_4289, dim = 1 : (tensor<3xi64>) -> tensor + %12189 = stablehlo.concatenate %12187, %12188, dim = 2 : (tensor, tensor) -> tensor + %12190 = "stablehlo.scatter"(%12127, %12189, %12186) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %12191 = stablehlo.slice %11810 
[6:7, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %12192 = stablehlo.reshape %12191 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %12193 = stablehlo.custom_call @byteir.non_zero(%12192) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_4290 = tensor.dim %12193, %c0 : tensor + %12194 = arith.index_cast %dim_4290 : index to i64 + %from_elements_4291 = tensor.from_elements %12194, %c1_i64 : tensor<2xi64> + %12195 = stablehlo.real_dynamic_slice %12193, %c_22, %from_elements_4291, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4292 = tensor.dim %12195, %c0 : tensor + %12196 = arith.index_cast %dim_4292 : index to i64 + %from_elements_4293 = tensor.from_elements %12196 : tensor<1xi64> + %12197 = stablehlo.dynamic_reshape %12195, %from_elements_4293 : (tensor, tensor<1xi64>) -> tensor + %from_elements_4294 = tensor.from_elements %12194, %c2_i64 : tensor<2xi64> + %12198 = stablehlo.real_dynamic_slice %12193, %c_24, %from_elements_4294, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4295 = tensor.dim %12198, %c0 : tensor + %12199 = arith.index_cast %dim_4295 : index to i64 + %from_elements_4296 = tensor.from_elements %12199 : tensor<1xi64> + %12200 = stablehlo.dynamic_reshape %12198, %from_elements_4296 : (tensor, tensor<1xi64>) -> tensor + %dim_4297 = tensor.dim %12200, %c0 : tensor + %12201 = arith.index_cast %dim_4297 : index to i64 + %from_elements_4298 = tensor.from_elements %12201, %c1_i64 : tensor<2xi64> + %12202 = stablehlo.dynamic_reshape %12200, %from_elements_4298 : (tensor, tensor<2xi64>) -> tensor + %dim_4299 = tensor.dim %12202, %c0 : tensor + %12203 = arith.index_cast %dim_4299 : index to i64 + %from_elements_4300 = tensor.from_elements %c1_i64, %12203, %c4096_i64 : tensor<3xi64> + %12204 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_4300, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4301 = tensor.dim %12204, %c1 : tensor<1x?x4096xi64> + %12205 = arith.index_cast %dim_4301 : index to i64 + %from_elements_4302 = tensor.from_elements %c1_i64, %12205, %c4096_i64, %c1_i64 : tensor<4xi64> + %12206 = stablehlo.dynamic_reshape %12204, %from_elements_4302 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %12207 = stablehlo.dynamic_broadcast_in_dim %12202, %from_elements_4300, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4303 = tensor.dim %12207, %c1 : tensor<1x?x4096xi64> + %12208 = arith.index_cast %dim_4303 : index to i64 + %from_elements_4304 = tensor.from_elements %c1_i64, %12208, %c4096_i64, %c1_i64 : tensor<4xi64> + %12209 = stablehlo.dynamic_reshape %12207, %from_elements_4304 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %12210 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_4300, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4305 = tensor.dim %12210, %c1 : tensor<1x?x4096xi64> + %12211 = arith.index_cast %dim_4305 : index to i64 + %from_elements_4306 = tensor.from_elements %c1_i64, %12211, %c4096_i64, %c1_i64 : tensor<4xi64> + %12212 = stablehlo.dynamic_reshape %12210, %from_elements_4306 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %12213 = stablehlo.concatenate %12206, %12209, %12212, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %12214 = "stablehlo.gather"(%11821, %12213) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = 
array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %12215 = shape.shape_of %12214 : tensor<1x?x4096xf32> -> tensor<3xindex> + %12216 = shape.num_elements %12215 : tensor<3xindex> -> index + %12217 = stablehlo.compute_reshape_shape %12216, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %12218 = stablehlo.dynamic_reshape %12214, %12217 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %12219 = stablehlo.dot %12218, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %12220 = stablehlo.logistic %12219 : tensor + %12221 = shape.shape_of %12220 : tensor -> tensor<2xindex> + %12222 = shape.shape_of %12219 : tensor -> tensor<2xindex> + %12223 = shape.cstr_broadcastable %12221, %12222 : tensor<2xindex>, tensor<2xindex> + %12224 = shape.assuming %12223 -> (tensor) { + %19688 = shape.broadcast %12221, %12222 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %12220, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %12219, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %12225 = shape.shape_of %12224 : tensor -> tensor<2xindex> + %12226 = shape.cstr_broadcastable %12225, %12222 : tensor<2xindex>, tensor<2xindex> + %12227 = shape.assuming %12226 -> (tensor) { + %19688 = shape.broadcast %12225, %12222 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %12224, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %12219, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %12228 = stablehlo.dot %12227, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_4307 = tensor.dim %12200, %c0 : tensor + %12229 = arith.index_cast %dim_4307 : index to i64 + %from_elements_4308 = tensor.from_elements %12229, %c1_i64 : tensor<2xi64> + %12230 = stablehlo.dynamic_reshape %12200, %from_elements_4308 : (tensor, tensor<2xi64>) -> tensor + %dim_4309 = tensor.dim %12197, %c0 : tensor + %12231 = arith.index_cast %dim_4309 : index to i64 + %from_elements_4310 = tensor.from_elements %12231, %c1_i64 : tensor<2xi64> + %12232 = stablehlo.dynamic_reshape %12197, %from_elements_4310 : (tensor, tensor<2xi64>) -> tensor + %12233 = stablehlo.concatenate %12230, %12232, dim = 1 : (tensor, tensor) -> tensor + %12234 = "stablehlo.gather"(%11850, %12233) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %12235 = shape.shape_of %12228 : tensor -> tensor<2xindex> + %12236 = shape.shape_of %12234 : tensor -> tensor<2xindex> + %12237 = shape.cstr_broadcastable %12235, %12236 : tensor<2xindex>, tensor<2xindex> + %12238 = shape.assuming %12237 -> (tensor) { + %19688 = shape.broadcast %12235, %12236 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %12228, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %12234, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %12239 = shape.shape_of %12238 : tensor -> tensor<2xindex> + %12240 = stablehlo.dynamic_broadcast_in_dim %12238, %12239, dims = [0, 1] 
: (tensor, tensor<2xindex>) -> tensor + %12241 = stablehlo.dynamic_broadcast_in_dim %213, %12239, dims = [] : (tensor, tensor<2xindex>) -> tensor + %12242 = stablehlo.multiply %12240, %12241 : tensor + %dim_4311 = tensor.dim %12202, %c0 : tensor + %12243 = arith.index_cast %dim_4311 : index to i64 + %dim_4312 = tensor.dim %12238, %c0 : tensor + %12244 = arith.index_cast %dim_4312 : index to i64 + %12245 = arith.maxsi %12243, %12244 : i64 + %12246 = arith.index_cast %12245 : i64 to index + %from_elements_4313 = tensor.from_elements %12246, %c4096 : tensor<2xindex> + %12247 = stablehlo.dynamic_broadcast_in_dim %12202, %from_elements_4313, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_4314 = tensor.dim %12247, %c0 : tensor + %12248 = arith.index_cast %dim_4314 : index to i64 + %from_elements_4315 = tensor.from_elements %12248, %c4096_i64 : tensor<2xi64> + %12249 = stablehlo.real_dynamic_slice %12242, %c_22, %from_elements_4315, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_4316 = tensor.from_elements %12248, %c4096_i64, %c1_i64 : tensor<3xi64> + %12250 = stablehlo.dynamic_reshape %12247, %from_elements_4316 : (tensor, tensor<3xi64>) -> tensor + %12251 = stablehlo.dynamic_iota %from_elements_4316, dim = 1 : (tensor<3xi64>) -> tensor + %12252 = stablehlo.concatenate %12250, %12251, dim = 2 : (tensor, tensor) -> tensor + %12253 = "stablehlo.scatter"(%12190, %12252, %12249) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %12254 = stablehlo.slice %11810 [7:8, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %12255 = stablehlo.reshape %12254 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %12256 = stablehlo.custom_call @byteir.non_zero(%12255) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_4317 = tensor.dim %12256, %c0 : tensor + %12257 = arith.index_cast %dim_4317 : index to i64 + %from_elements_4318 = tensor.from_elements %12257, %c1_i64 : tensor<2xi64> + %12258 = stablehlo.real_dynamic_slice %12256, %c_22, %from_elements_4318, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4319 = tensor.dim %12258, %c0 : tensor + %12259 = arith.index_cast %dim_4319 : index to i64 + %from_elements_4320 = tensor.from_elements %12259 : tensor<1xi64> + %12260 = stablehlo.dynamic_reshape %12258, %from_elements_4320 : (tensor, tensor<1xi64>) -> tensor + %from_elements_4321 = tensor.from_elements %12257, %c2_i64 : tensor<2xi64> + %12261 = stablehlo.real_dynamic_slice %12256, %c_24, %from_elements_4321, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4322 = tensor.dim %12261, %c0 : tensor + %12262 = arith.index_cast %dim_4322 : index to i64 + %from_elements_4323 = tensor.from_elements %12262 : tensor<1xi64> + %12263 = stablehlo.dynamic_reshape %12261, %from_elements_4323 : (tensor, tensor<1xi64>) -> tensor + %dim_4324 = tensor.dim %12263, %c0 : tensor + %12264 = arith.index_cast %dim_4324 : index to i64 + %from_elements_4325 = tensor.from_elements %12264, %c1_i64 : tensor<2xi64> + %12265 = stablehlo.dynamic_reshape %12263, %from_elements_4325 : (tensor, tensor<2xi64>) -> tensor + %dim_4326 = tensor.dim %12265, %c0 : tensor + %12266 = arith.index_cast %dim_4326 : index to i64 + %from_elements_4327 = tensor.from_elements %c1_i64, %12266, 
%c4096_i64 : tensor<3xi64> + %12267 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_4327, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4328 = tensor.dim %12267, %c1 : tensor<1x?x4096xi64> + %12268 = arith.index_cast %dim_4328 : index to i64 + %from_elements_4329 = tensor.from_elements %c1_i64, %12268, %c4096_i64, %c1_i64 : tensor<4xi64> + %12269 = stablehlo.dynamic_reshape %12267, %from_elements_4329 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %12270 = stablehlo.dynamic_broadcast_in_dim %12265, %from_elements_4327, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4330 = tensor.dim %12270, %c1 : tensor<1x?x4096xi64> + %12271 = arith.index_cast %dim_4330 : index to i64 + %from_elements_4331 = tensor.from_elements %c1_i64, %12271, %c4096_i64, %c1_i64 : tensor<4xi64> + %12272 = stablehlo.dynamic_reshape %12270, %from_elements_4331 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %12273 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_4327, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4332 = tensor.dim %12273, %c1 : tensor<1x?x4096xi64> + %12274 = arith.index_cast %dim_4332 : index to i64 + %from_elements_4333 = tensor.from_elements %c1_i64, %12274, %c4096_i64, %c1_i64 : tensor<4xi64> + %12275 = stablehlo.dynamic_reshape %12273, %from_elements_4333 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %12276 = stablehlo.concatenate %12269, %12272, %12275, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %12277 = "stablehlo.gather"(%11821, %12276) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %12278 = shape.shape_of %12277 : tensor<1x?x4096xf32> -> tensor<3xindex> + %12279 = shape.num_elements %12278 : tensor<3xindex> -> index + %12280 = stablehlo.compute_reshape_shape %12279, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %12281 = stablehlo.dynamic_reshape %12277, %12280 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %12282 = stablehlo.dot %12281, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %12283 = stablehlo.logistic %12282 : tensor + %12284 = shape.shape_of %12283 : tensor -> tensor<2xindex> + %12285 = shape.shape_of %12282 : tensor -> tensor<2xindex> + %12286 = shape.cstr_broadcastable %12284, %12285 : tensor<2xindex>, tensor<2xindex> + %12287 = shape.assuming %12286 -> (tensor) { + %19688 = shape.broadcast %12284, %12285 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %12283, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %12282, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %12288 = shape.shape_of %12287 : tensor -> tensor<2xindex> + %12289 = shape.cstr_broadcastable %12288, %12285 : tensor<2xindex>, tensor<2xindex> + %12290 = shape.assuming %12289 -> (tensor) { + %19688 = shape.broadcast %12288, %12285 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %12287, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %12282, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = 
stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %12291 = stablehlo.dot %12290, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_4334 = tensor.dim %12263, %c0 : tensor + %12292 = arith.index_cast %dim_4334 : index to i64 + %from_elements_4335 = tensor.from_elements %12292, %c1_i64 : tensor<2xi64> + %12293 = stablehlo.dynamic_reshape %12263, %from_elements_4335 : (tensor, tensor<2xi64>) -> tensor + %dim_4336 = tensor.dim %12260, %c0 : tensor + %12294 = arith.index_cast %dim_4336 : index to i64 + %from_elements_4337 = tensor.from_elements %12294, %c1_i64 : tensor<2xi64> + %12295 = stablehlo.dynamic_reshape %12260, %from_elements_4337 : (tensor, tensor<2xi64>) -> tensor + %12296 = stablehlo.concatenate %12293, %12295, dim = 1 : (tensor, tensor) -> tensor + %12297 = "stablehlo.gather"(%11850, %12296) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %12298 = shape.shape_of %12291 : tensor -> tensor<2xindex> + %12299 = shape.shape_of %12297 : tensor -> tensor<2xindex> + %12300 = shape.cstr_broadcastable %12298, %12299 : tensor<2xindex>, tensor<2xindex> + %12301 = shape.assuming %12300 -> (tensor) { + %19688 = shape.broadcast %12298, %12299 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %12291, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %12297, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %12302 = shape.shape_of %12301 : tensor -> tensor<2xindex> + %12303 = stablehlo.dynamic_broadcast_in_dim %12301, %12302, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %12304 = stablehlo.dynamic_broadcast_in_dim %213, %12302, dims = [] : (tensor, tensor<2xindex>) -> tensor + %12305 = stablehlo.multiply %12303, %12304 : tensor + %dim_4338 = tensor.dim %12265, %c0 : tensor + %12306 = arith.index_cast %dim_4338 : index to i64 + %dim_4339 = tensor.dim %12301, %c0 : tensor + %12307 = arith.index_cast %dim_4339 : index to i64 + %12308 = arith.maxsi %12306, %12307 : i64 + %12309 = arith.index_cast %12308 : i64 to index + %from_elements_4340 = tensor.from_elements %12309, %c4096 : tensor<2xindex> + %12310 = stablehlo.dynamic_broadcast_in_dim %12265, %from_elements_4340, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_4341 = tensor.dim %12310, %c0 : tensor + %12311 = arith.index_cast %dim_4341 : index to i64 + %from_elements_4342 = tensor.from_elements %12311, %c4096_i64 : tensor<2xi64> + %12312 = stablehlo.real_dynamic_slice %12305, %c_22, %from_elements_4342, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_4343 = tensor.from_elements %12311, %c4096_i64, %c1_i64 : tensor<3xi64> + %12313 = stablehlo.dynamic_reshape %12310, %from_elements_4343 : (tensor, tensor<3xi64>) -> tensor + %12314 = stablehlo.dynamic_iota %from_elements_4343, dim = 1 : (tensor<3xi64>) -> tensor + %12315 = stablehlo.concatenate %12313, %12314, dim = 2 : (tensor, tensor) -> tensor + %12316 = "stablehlo.scatter"(%12253, %12315, %12312) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %12317 
= stablehlo.reshape %12316 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %12318 = stablehlo.add %11783, %12317 : tensor<3x1x4096xf32> + %12319 = stablehlo.broadcast_in_dim %12318, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %12320 = stablehlo.power %12319, %15 : tensor<3x1x4096xf32> + %12321 = stablehlo.reduce(%12320 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %12322 = stablehlo.reshape %12321 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %12323 = stablehlo.broadcast_in_dim %12322, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %12324 = stablehlo.divide %12323, %21 : tensor<3x1x1xf32> + %12325 = stablehlo.broadcast_in_dim %12324, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %12326 = stablehlo.add %12325, %25 : tensor<3x1x1xf32> + %12327 = stablehlo.rsqrt %12326 : tensor<3x1x1xf32> + %12328 = stablehlo.broadcast_in_dim %12327, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %12329 = stablehlo.multiply %12319, %12328 : tensor<3x1x4096xf32> + %12330 = stablehlo.broadcast_in_dim %12329, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %12331 = stablehlo.multiply %12330, %31 : tensor<3x1x4096xf32> + %12332 = stablehlo.reshape %12331 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %12333 = stablehlo.dot %12332, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %12334 = stablehlo.reshape %12333 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %12335 = stablehlo.dot %12332, %37 : (tensor<3x4096xf32>, tensor<4096x1024xf32>) -> tensor<3x1024xf32> + %12336 = stablehlo.reshape %12335 : (tensor<3x1024xf32>) -> tensor<3x1x1024xf32> + %12337 = stablehlo.reshape %12334 : (tensor<3x1x4096xf32>) -> tensor<3x1x32x128xf32> + %12338 = stablehlo.transpose %12337, dims = [0, 2, 1, 3] : (tensor<3x1x32x128xf32>) -> tensor<3x32x1x128xf32> + %12339 = stablehlo.reshape %12336 : (tensor<3x1x1024xf32>) -> tensor<3x1x8x128xf32> + %12340 = stablehlo.transpose %12339, dims = [0, 2, 1, 3] : (tensor<3x1x8x128xf32>) -> tensor<3x8x1x128xf32> + %12341 = stablehlo.slice %arg40 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %12342 = stablehlo.slice %arg41 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %12343 = "stablehlo.gather"(%12341, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %12344 = stablehlo.reshape %12343 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %12345 = "stablehlo.gather"(%12342, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %12346 = stablehlo.reshape %12345 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %12347 = stablehlo.broadcast_in_dim %12338, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %12348 = stablehlo.broadcast_in_dim %12344, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %12349 = stablehlo.multiply %12347, %12348 : tensor<3x32x1x128xf32> + %12350 = stablehlo.slice %12338 [0:3, 0:32, 0:1, 0:64] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %12351 = stablehlo.slice %12338 [0:3, 0:32, 0:1, 64:128] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %12352 = stablehlo.negate %12351 : tensor<3x32x1x64xf32> + %12353 = stablehlo.concatenate %12352, %12350, dim = 3 : (tensor<3x32x1x64xf32>, 
tensor<3x32x1x64xf32>) -> tensor<3x32x1x128xf32> + %12354 = stablehlo.broadcast_in_dim %12353, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %12355 = stablehlo.broadcast_in_dim %12346, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %12356 = stablehlo.multiply %12354, %12355 : tensor<3x32x1x128xf32> + %12357 = stablehlo.add %12349, %12356 : tensor<3x32x1x128xf32> + %12358 = stablehlo.broadcast_in_dim %12340, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %12359 = stablehlo.broadcast_in_dim %12344, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %12360 = stablehlo.multiply %12358, %12359 : tensor<3x8x1x128xf32> + %12361 = stablehlo.slice %12340 [0:3, 0:8, 0:1, 0:64] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %12362 = stablehlo.slice %12340 [0:3, 0:8, 0:1, 64:128] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %12363 = stablehlo.negate %12362 : tensor<3x8x1x64xf32> + %12364 = stablehlo.concatenate %12363, %12361, dim = 3 : (tensor<3x8x1x64xf32>, tensor<3x8x1x64xf32>) -> tensor<3x8x1x128xf32> + %12365 = stablehlo.broadcast_in_dim %12364, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %12366 = stablehlo.broadcast_in_dim %12346, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %12367 = stablehlo.multiply %12365, %12366 : tensor<3x8x1x128xf32> + %12368 = stablehlo.add %12360, %12367 : tensor<3x8x1x128xf32> + %12369 = stablehlo.concatenate %arg105, %12368, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %12370 = stablehlo.concatenate %arg106, %12340, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %12371 = stablehlo.reshape %12369 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %12372 = stablehlo.broadcast_in_dim %12371, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %12373 = stablehlo.reshape %12372 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %12374 = stablehlo.reshape %12370 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %12375 = stablehlo.broadcast_in_dim %12374, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %12376 = stablehlo.reshape %12375 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %12377 = stablehlo.transpose %12373, dims = [0, 1, 3, 2] : (tensor<3x32x8x128xf32>) -> tensor<3x32x128x8xf32> + %12378 = stablehlo.reshape %12357 : (tensor<3x32x1x128xf32>) -> tensor<96x1x128xf32> + %12379 = stablehlo.reshape %12377 : (tensor<3x32x128x8xf32>) -> tensor<96x128x8xf32> + %12380 = stablehlo.broadcast_in_dim %12379, dims = [0, 1, 2] : (tensor<96x128x8xf32>) -> tensor<96x128x8xf32> + %12381 = stablehlo.dot_general %12378, %12380, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x128xf32>, tensor<96x128x8xf32>) -> tensor<96x1x8xf32> + %12382 = stablehlo.reshape %12381 : (tensor<96x1x8xf32>) -> tensor<3x32x1x8xf32> + %12383 = stablehlo.broadcast_in_dim %12382, dims = [0, 1, 2, 3] : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %12384 = stablehlo.divide %12383, %89 : tensor<3x32x1x8xf32> + %12385 = stablehlo.custom_call @byteir.softmax(%12384) {byteir_attrs = {axis = 3 : i64}} : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %12386 = stablehlo.reshape %12385 : (tensor<3x32x1x8xf32>) -> tensor<96x1x8xf32> + %12387 = stablehlo.reshape %12376 : (tensor<3x32x8x128xf32>) -> tensor<96x8x128xf32> + %12388 = stablehlo.broadcast_in_dim %12387, 
dims = [0, 1, 2] : (tensor<96x8x128xf32>) -> tensor<96x8x128xf32> + %12389 = stablehlo.dot_general %12386, %12388, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x8xf32>, tensor<96x8x128xf32>) -> tensor<96x1x128xf32> + %12390 = stablehlo.reshape %12389 : (tensor<96x1x128xf32>) -> tensor<3x32x1x128xf32> + %12391 = stablehlo.transpose %12390, dims = [0, 2, 1, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x1x32x128xf32> + %12392 = stablehlo.reshape %12391 : (tensor<3x1x32x128xf32>) -> tensor<3x1x4096xf32> + %12393 = stablehlo.reshape %12392 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %12394 = stablehlo.dot %12393, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %12395 = stablehlo.reshape %12394 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %12396 = stablehlo.add %12318, %12395 : tensor<3x1x4096xf32> + %12397 = stablehlo.broadcast_in_dim %12396, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %12398 = stablehlo.power %12397, %15 : tensor<3x1x4096xf32> + %12399 = stablehlo.reduce(%12398 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %12400 = stablehlo.reshape %12399 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %12401 = stablehlo.broadcast_in_dim %12400, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %12402 = stablehlo.divide %12401, %21 : tensor<3x1x1xf32> + %12403 = stablehlo.broadcast_in_dim %12402, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %12404 = stablehlo.add %12403, %25 : tensor<3x1x1xf32> + %12405 = stablehlo.rsqrt %12404 : tensor<3x1x1xf32> + %12406 = stablehlo.broadcast_in_dim %12405, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %12407 = stablehlo.multiply %12397, %12406 : tensor<3x1x4096xf32> + %12408 = stablehlo.broadcast_in_dim %12407, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %12409 = stablehlo.multiply %12408, %31 : tensor<3x1x4096xf32> + %12410 = stablehlo.reshape %12409 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %12411 = stablehlo.dot %12410, %117 : (tensor<3x4096xf32>, tensor<4096x8xf32>) -> tensor<3x8xf32> + %12412 = stablehlo.custom_call @byteir.softmax(%12411) {byteir_attrs = {axis = 1 : i64}} : (tensor<3x8xf32>) -> tensor<3x8xf32> + %12413:2 = stablehlo.custom_call @byteir.top_k(%12412) {byteir_attrs = {axis = [1], k = 2 : i64, sorted = true}} : (tensor<3x8xf32>) -> (tensor<3x2xf32>, tensor<3x2xi64>) + %12414 = stablehlo.reduce(%12413#0 init: %cst_20) applies stablehlo.add across dimensions = [1] : (tensor<3x2xf32>, tensor) -> tensor<3xf32> + %12415 = stablehlo.reshape %12414 : (tensor<3xf32>) -> tensor<3x1xf32> + %12416 = stablehlo.broadcast_in_dim %12413#0, dims = [0, 1] : (tensor<3x2xf32>) -> tensor<3x2xf32> + %12417 = stablehlo.broadcast_in_dim %12415, dims = [0, 1] : (tensor<3x1xf32>) -> tensor<3x2xf32> + %12418 = stablehlo.divide %12416, %12417 : tensor<3x2xf32> + %12419 = stablehlo.reshape %12413#1 : (tensor<3x2xi64>) -> tensor<3x2x1xi64> + %12420 = stablehlo.broadcast_in_dim %12419, dims = [0, 1, 2] : (tensor<3x2x1xi64>) -> tensor<3x2x8xi64> + %12421 = stablehlo.compare EQ, %12420, %137, SIGNED : (tensor<3x2x8xi64>, tensor<3x2x8xi64>) -> tensor<3x2x8xi1> + %12422 = stablehlo.convert %12421 : (tensor<3x2x8xi1>) -> tensor<3x2x8xi64> + %12423 = stablehlo.transpose %12422, dims = [2, 1, 0] : (tensor<3x2x8xi64>) -> tensor<8x2x3xi64> + %12424 = stablehlo.slice %12423 [0:1, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %12425 = 
stablehlo.reshape %12424 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %12426 = stablehlo.custom_call @byteir.non_zero(%12425) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_4344 = tensor.dim %12426, %c0 : tensor + %12427 = arith.index_cast %dim_4344 : index to i64 + %from_elements_4345 = tensor.from_elements %12427, %c1_i64 : tensor<2xi64> + %12428 = stablehlo.real_dynamic_slice %12426, %c_22, %from_elements_4345, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4346 = tensor.dim %12428, %c0 : tensor + %12429 = arith.index_cast %dim_4346 : index to i64 + %from_elements_4347 = tensor.from_elements %12429 : tensor<1xi64> + %12430 = stablehlo.dynamic_reshape %12428, %from_elements_4347 : (tensor, tensor<1xi64>) -> tensor + %from_elements_4348 = tensor.from_elements %12427, %c2_i64 : tensor<2xi64> + %12431 = stablehlo.real_dynamic_slice %12426, %c_24, %from_elements_4348, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4349 = tensor.dim %12431, %c0 : tensor + %12432 = arith.index_cast %dim_4349 : index to i64 + %from_elements_4350 = tensor.from_elements %12432 : tensor<1xi64> + %12433 = stablehlo.dynamic_reshape %12431, %from_elements_4350 : (tensor, tensor<1xi64>) -> tensor + %12434 = stablehlo.reshape %12410 : (tensor<3x4096xf32>) -> tensor<1x3x4096xf32> + %dim_4351 = tensor.dim %12433, %c0 : tensor + %12435 = arith.index_cast %dim_4351 : index to i64 + %from_elements_4352 = tensor.from_elements %12435, %c1_i64 : tensor<2xi64> + %12436 = stablehlo.dynamic_reshape %12433, %from_elements_4352 : (tensor, tensor<2xi64>) -> tensor + %dim_4353 = tensor.dim %12436, %c0 : tensor + %12437 = arith.index_cast %dim_4353 : index to i64 + %from_elements_4354 = tensor.from_elements %c1_i64, %12437, %c4096_i64 : tensor<3xi64> + %12438 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_4354, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4355 = tensor.dim %12438, %c1 : tensor<1x?x4096xi64> + %12439 = arith.index_cast %dim_4355 : index to i64 + %from_elements_4356 = tensor.from_elements %c1_i64, %12439, %c4096_i64, %c1_i64 : tensor<4xi64> + %12440 = stablehlo.dynamic_reshape %12438, %from_elements_4356 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %12441 = stablehlo.dynamic_broadcast_in_dim %12436, %from_elements_4354, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4357 = tensor.dim %12441, %c1 : tensor<1x?x4096xi64> + %12442 = arith.index_cast %dim_4357 : index to i64 + %from_elements_4358 = tensor.from_elements %c1_i64, %12442, %c4096_i64, %c1_i64 : tensor<4xi64> + %12443 = stablehlo.dynamic_reshape %12441, %from_elements_4358 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %12444 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_4354, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4359 = tensor.dim %12444, %c1 : tensor<1x?x4096xi64> + %12445 = arith.index_cast %dim_4359 : index to i64 + %from_elements_4360 = tensor.from_elements %c1_i64, %12445, %c4096_i64, %c1_i64 : tensor<4xi64> + %12446 = stablehlo.dynamic_reshape %12444, %from_elements_4360 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %12447 = stablehlo.concatenate %12440, %12443, %12446, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %12448 = "stablehlo.gather"(%12434, %12447) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, 
slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %12449 = shape.shape_of %12448 : tensor<1x?x4096xf32> -> tensor<3xindex> + %12450 = shape.num_elements %12449 : tensor<3xindex> -> index + %12451 = stablehlo.compute_reshape_shape %12450, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %12452 = stablehlo.dynamic_reshape %12448, %12451 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %12453 = stablehlo.dot %12452, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %12454 = stablehlo.logistic %12453 : tensor + %12455 = shape.shape_of %12454 : tensor -> tensor<2xindex> + %12456 = shape.shape_of %12453 : tensor -> tensor<2xindex> + %12457 = shape.cstr_broadcastable %12455, %12456 : tensor<2xindex>, tensor<2xindex> + %12458 = shape.assuming %12457 -> (tensor) { + %19688 = shape.broadcast %12455, %12456 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %12454, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %12453, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %12459 = shape.shape_of %12458 : tensor -> tensor<2xindex> + %12460 = shape.cstr_broadcastable %12459, %12456 : tensor<2xindex>, tensor<2xindex> + %12461 = shape.assuming %12460 -> (tensor) { + %19688 = shape.broadcast %12459, %12456 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %12458, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %12453, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %12462 = stablehlo.dot %12461, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %12463 = stablehlo.reshape %12418 : (tensor<3x2xf32>) -> tensor<3x2x1xf32> + %dim_4361 = tensor.dim %12433, %c0 : tensor + %12464 = arith.index_cast %dim_4361 : index to i64 + %from_elements_4362 = tensor.from_elements %12464, %c1_i64 : tensor<2xi64> + %12465 = stablehlo.dynamic_reshape %12433, %from_elements_4362 : (tensor, tensor<2xi64>) -> tensor + %dim_4363 = tensor.dim %12430, %c0 : tensor + %12466 = arith.index_cast %dim_4363 : index to i64 + %from_elements_4364 = tensor.from_elements %12466, %c1_i64 : tensor<2xi64> + %12467 = stablehlo.dynamic_reshape %12430, %from_elements_4364 : (tensor, tensor<2xi64>) -> tensor + %12468 = stablehlo.concatenate %12465, %12467, dim = 1 : (tensor, tensor) -> tensor + %12469 = "stablehlo.gather"(%12463, %12468) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %12470 = shape.shape_of %12462 : tensor -> tensor<2xindex> + %12471 = shape.shape_of %12469 : tensor -> tensor<2xindex> + %12472 = shape.cstr_broadcastable %12470, %12471 : tensor<2xindex>, tensor<2xindex> + %12473 = shape.assuming %12472 -> (tensor) { + %19688 = shape.broadcast %12470, %12471 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %12462, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %12469, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %12474 = shape.shape_of %12473 : tensor -> 
tensor<2xindex> + %12475 = stablehlo.dynamic_broadcast_in_dim %12473, %12474, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %12476 = stablehlo.dynamic_broadcast_in_dim %213, %12474, dims = [] : (tensor, tensor<2xindex>) -> tensor + %12477 = stablehlo.multiply %12475, %12476 : tensor + %dim_4365 = tensor.dim %12436, %c0 : tensor + %12478 = arith.index_cast %dim_4365 : index to i64 + %dim_4366 = tensor.dim %12473, %c0 : tensor + %12479 = arith.index_cast %dim_4366 : index to i64 + %12480 = arith.maxsi %12478, %12479 : i64 + %12481 = arith.index_cast %12480 : i64 to index + %from_elements_4367 = tensor.from_elements %12481, %c4096 : tensor<2xindex> + %12482 = stablehlo.dynamic_broadcast_in_dim %12436, %from_elements_4367, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_4368 = tensor.dim %12482, %c0 : tensor + %12483 = arith.index_cast %dim_4368 : index to i64 + %from_elements_4369 = tensor.from_elements %12483, %c4096_i64 : tensor<2xi64> + %12484 = stablehlo.real_dynamic_slice %12477, %c_22, %from_elements_4369, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_4370 = tensor.from_elements %12483, %c4096_i64, %c1_i64 : tensor<3xi64> + %12485 = stablehlo.dynamic_reshape %12482, %from_elements_4370 : (tensor, tensor<3xi64>) -> tensor + %12486 = stablehlo.dynamic_iota %from_elements_4370, dim = 1 : (tensor<3xi64>) -> tensor + %12487 = stablehlo.concatenate %12485, %12486, dim = 2 : (tensor, tensor) -> tensor + %12488 = "stablehlo.scatter"(%cst_2, %12487, %12484) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %12489 = stablehlo.slice %12423 [1:2, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %12490 = stablehlo.reshape %12489 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %12491 = stablehlo.custom_call @byteir.non_zero(%12490) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_4371 = tensor.dim %12491, %c0 : tensor + %12492 = arith.index_cast %dim_4371 : index to i64 + %from_elements_4372 = tensor.from_elements %12492, %c1_i64 : tensor<2xi64> + %12493 = stablehlo.real_dynamic_slice %12491, %c_22, %from_elements_4372, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4373 = tensor.dim %12493, %c0 : tensor + %12494 = arith.index_cast %dim_4373 : index to i64 + %from_elements_4374 = tensor.from_elements %12494 : tensor<1xi64> + %12495 = stablehlo.dynamic_reshape %12493, %from_elements_4374 : (tensor, tensor<1xi64>) -> tensor + %from_elements_4375 = tensor.from_elements %12492, %c2_i64 : tensor<2xi64> + %12496 = stablehlo.real_dynamic_slice %12491, %c_24, %from_elements_4375, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4376 = tensor.dim %12496, %c0 : tensor + %12497 = arith.index_cast %dim_4376 : index to i64 + %from_elements_4377 = tensor.from_elements %12497 : tensor<1xi64> + %12498 = stablehlo.dynamic_reshape %12496, %from_elements_4377 : (tensor, tensor<1xi64>) -> tensor + %dim_4378 = tensor.dim %12498, %c0 : tensor + %12499 = arith.index_cast %dim_4378 : index to i64 + %from_elements_4379 = tensor.from_elements %12499, %c1_i64 : tensor<2xi64> + %12500 = stablehlo.dynamic_reshape %12498, %from_elements_4379 : (tensor, tensor<2xi64>) -> tensor + %dim_4380 = tensor.dim %12500, %c0 : tensor + %12501 = arith.index_cast 
%dim_4380 : index to i64 + %from_elements_4381 = tensor.from_elements %c1_i64, %12501, %c4096_i64 : tensor<3xi64> + %12502 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_4381, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4382 = tensor.dim %12502, %c1 : tensor<1x?x4096xi64> + %12503 = arith.index_cast %dim_4382 : index to i64 + %from_elements_4383 = tensor.from_elements %c1_i64, %12503, %c4096_i64, %c1_i64 : tensor<4xi64> + %12504 = stablehlo.dynamic_reshape %12502, %from_elements_4383 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %12505 = stablehlo.dynamic_broadcast_in_dim %12500, %from_elements_4381, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4384 = tensor.dim %12505, %c1 : tensor<1x?x4096xi64> + %12506 = arith.index_cast %dim_4384 : index to i64 + %from_elements_4385 = tensor.from_elements %c1_i64, %12506, %c4096_i64, %c1_i64 : tensor<4xi64> + %12507 = stablehlo.dynamic_reshape %12505, %from_elements_4385 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %12508 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_4381, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4386 = tensor.dim %12508, %c1 : tensor<1x?x4096xi64> + %12509 = arith.index_cast %dim_4386 : index to i64 + %from_elements_4387 = tensor.from_elements %c1_i64, %12509, %c4096_i64, %c1_i64 : tensor<4xi64> + %12510 = stablehlo.dynamic_reshape %12508, %from_elements_4387 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %12511 = stablehlo.concatenate %12504, %12507, %12510, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %12512 = "stablehlo.gather"(%12434, %12511) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %12513 = shape.shape_of %12512 : tensor<1x?x4096xf32> -> tensor<3xindex> + %12514 = shape.num_elements %12513 : tensor<3xindex> -> index + %12515 = stablehlo.compute_reshape_shape %12514, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %12516 = stablehlo.dynamic_reshape %12512, %12515 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %12517 = stablehlo.dot %12516, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %12518 = stablehlo.logistic %12517 : tensor + %12519 = shape.shape_of %12518 : tensor -> tensor<2xindex> + %12520 = shape.shape_of %12517 : tensor -> tensor<2xindex> + %12521 = shape.cstr_broadcastable %12519, %12520 : tensor<2xindex>, tensor<2xindex> + %12522 = shape.assuming %12521 -> (tensor) { + %19688 = shape.broadcast %12519, %12520 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %12518, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %12517, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %12523 = shape.shape_of %12522 : tensor -> tensor<2xindex> + %12524 = shape.cstr_broadcastable %12523, %12520 : tensor<2xindex>, tensor<2xindex> + %12525 = shape.assuming %12524 -> (tensor) { + %19688 = shape.broadcast %12523, %12520 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %12522, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %12517, 
%19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %12526 = stablehlo.dot %12525, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_4388 = tensor.dim %12498, %c0 : tensor + %12527 = arith.index_cast %dim_4388 : index to i64 + %from_elements_4389 = tensor.from_elements %12527, %c1_i64 : tensor<2xi64> + %12528 = stablehlo.dynamic_reshape %12498, %from_elements_4389 : (tensor, tensor<2xi64>) -> tensor + %dim_4390 = tensor.dim %12495, %c0 : tensor + %12529 = arith.index_cast %dim_4390 : index to i64 + %from_elements_4391 = tensor.from_elements %12529, %c1_i64 : tensor<2xi64> + %12530 = stablehlo.dynamic_reshape %12495, %from_elements_4391 : (tensor, tensor<2xi64>) -> tensor + %12531 = stablehlo.concatenate %12528, %12530, dim = 1 : (tensor, tensor) -> tensor + %12532 = "stablehlo.gather"(%12463, %12531) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %12533 = shape.shape_of %12526 : tensor -> tensor<2xindex> + %12534 = shape.shape_of %12532 : tensor -> tensor<2xindex> + %12535 = shape.cstr_broadcastable %12533, %12534 : tensor<2xindex>, tensor<2xindex> + %12536 = shape.assuming %12535 -> (tensor) { + %19688 = shape.broadcast %12533, %12534 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %12526, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %12532, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %12537 = shape.shape_of %12536 : tensor -> tensor<2xindex> + %12538 = stablehlo.dynamic_broadcast_in_dim %12536, %12537, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %12539 = stablehlo.dynamic_broadcast_in_dim %213, %12537, dims = [] : (tensor, tensor<2xindex>) -> tensor + %12540 = stablehlo.multiply %12538, %12539 : tensor + %dim_4392 = tensor.dim %12500, %c0 : tensor + %12541 = arith.index_cast %dim_4392 : index to i64 + %dim_4393 = tensor.dim %12536, %c0 : tensor + %12542 = arith.index_cast %dim_4393 : index to i64 + %12543 = arith.maxsi %12541, %12542 : i64 + %12544 = arith.index_cast %12543 : i64 to index + %from_elements_4394 = tensor.from_elements %12544, %c4096 : tensor<2xindex> + %12545 = stablehlo.dynamic_broadcast_in_dim %12500, %from_elements_4394, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_4395 = tensor.dim %12545, %c0 : tensor + %12546 = arith.index_cast %dim_4395 : index to i64 + %from_elements_4396 = tensor.from_elements %12546, %c4096_i64 : tensor<2xi64> + %12547 = stablehlo.real_dynamic_slice %12540, %c_22, %from_elements_4396, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_4397 = tensor.from_elements %12546, %c4096_i64, %c1_i64 : tensor<3xi64> + %12548 = stablehlo.dynamic_reshape %12545, %from_elements_4397 : (tensor, tensor<3xi64>) -> tensor + %12549 = stablehlo.dynamic_iota %from_elements_4397, dim = 1 : (tensor<3xi64>) -> tensor + %12550 = stablehlo.concatenate %12548, %12549, dim = 2 : (tensor, tensor) -> tensor + %12551 = "stablehlo.scatter"(%12488, %12550, %12547) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) 
: (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %12552 = stablehlo.slice %12423 [2:3, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %12553 = stablehlo.reshape %12552 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %12554 = stablehlo.custom_call @byteir.non_zero(%12553) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_4398 = tensor.dim %12554, %c0 : tensor + %12555 = arith.index_cast %dim_4398 : index to i64 + %from_elements_4399 = tensor.from_elements %12555, %c1_i64 : tensor<2xi64> + %12556 = stablehlo.real_dynamic_slice %12554, %c_22, %from_elements_4399, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4400 = tensor.dim %12556, %c0 : tensor + %12557 = arith.index_cast %dim_4400 : index to i64 + %from_elements_4401 = tensor.from_elements %12557 : tensor<1xi64> + %12558 = stablehlo.dynamic_reshape %12556, %from_elements_4401 : (tensor, tensor<1xi64>) -> tensor + %from_elements_4402 = tensor.from_elements %12555, %c2_i64 : tensor<2xi64> + %12559 = stablehlo.real_dynamic_slice %12554, %c_24, %from_elements_4402, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4403 = tensor.dim %12559, %c0 : tensor + %12560 = arith.index_cast %dim_4403 : index to i64 + %from_elements_4404 = tensor.from_elements %12560 : tensor<1xi64> + %12561 = stablehlo.dynamic_reshape %12559, %from_elements_4404 : (tensor, tensor<1xi64>) -> tensor + %dim_4405 = tensor.dim %12561, %c0 : tensor + %12562 = arith.index_cast %dim_4405 : index to i64 + %from_elements_4406 = tensor.from_elements %12562, %c1_i64 : tensor<2xi64> + %12563 = stablehlo.dynamic_reshape %12561, %from_elements_4406 : (tensor, tensor<2xi64>) -> tensor + %dim_4407 = tensor.dim %12563, %c0 : tensor + %12564 = arith.index_cast %dim_4407 : index to i64 + %from_elements_4408 = tensor.from_elements %c1_i64, %12564, %c4096_i64 : tensor<3xi64> + %12565 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_4408, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4409 = tensor.dim %12565, %c1 : tensor<1x?x4096xi64> + %12566 = arith.index_cast %dim_4409 : index to i64 + %from_elements_4410 = tensor.from_elements %c1_i64, %12566, %c4096_i64, %c1_i64 : tensor<4xi64> + %12567 = stablehlo.dynamic_reshape %12565, %from_elements_4410 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %12568 = stablehlo.dynamic_broadcast_in_dim %12563, %from_elements_4408, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4411 = tensor.dim %12568, %c1 : tensor<1x?x4096xi64> + %12569 = arith.index_cast %dim_4411 : index to i64 + %from_elements_4412 = tensor.from_elements %c1_i64, %12569, %c4096_i64, %c1_i64 : tensor<4xi64> + %12570 = stablehlo.dynamic_reshape %12568, %from_elements_4412 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %12571 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_4408, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4413 = tensor.dim %12571, %c1 : tensor<1x?x4096xi64> + %12572 = arith.index_cast %dim_4413 : index to i64 + %from_elements_4414 = tensor.from_elements %c1_i64, %12572, %c4096_i64, %c1_i64 : tensor<4xi64> + %12573 = stablehlo.dynamic_reshape %12571, %from_elements_4414 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %12574 = stablehlo.concatenate %12567, %12570, %12573, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %12575 = "stablehlo.gather"(%12434, 
%12574) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %12576 = shape.shape_of %12575 : tensor<1x?x4096xf32> -> tensor<3xindex> + %12577 = shape.num_elements %12576 : tensor<3xindex> -> index + %12578 = stablehlo.compute_reshape_shape %12577, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %12579 = stablehlo.dynamic_reshape %12575, %12578 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %12580 = stablehlo.dot %12579, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %12581 = stablehlo.logistic %12580 : tensor + %12582 = shape.shape_of %12581 : tensor -> tensor<2xindex> + %12583 = shape.shape_of %12580 : tensor -> tensor<2xindex> + %12584 = shape.cstr_broadcastable %12582, %12583 : tensor<2xindex>, tensor<2xindex> + %12585 = shape.assuming %12584 -> (tensor) { + %19688 = shape.broadcast %12582, %12583 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %12581, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %12580, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %12586 = shape.shape_of %12585 : tensor -> tensor<2xindex> + %12587 = shape.cstr_broadcastable %12586, %12583 : tensor<2xindex>, tensor<2xindex> + %12588 = shape.assuming %12587 -> (tensor) { + %19688 = shape.broadcast %12586, %12583 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %12585, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %12580, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %12589 = stablehlo.dot %12588, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_4415 = tensor.dim %12561, %c0 : tensor + %12590 = arith.index_cast %dim_4415 : index to i64 + %from_elements_4416 = tensor.from_elements %12590, %c1_i64 : tensor<2xi64> + %12591 = stablehlo.dynamic_reshape %12561, %from_elements_4416 : (tensor, tensor<2xi64>) -> tensor + %dim_4417 = tensor.dim %12558, %c0 : tensor + %12592 = arith.index_cast %dim_4417 : index to i64 + %from_elements_4418 = tensor.from_elements %12592, %c1_i64 : tensor<2xi64> + %12593 = stablehlo.dynamic_reshape %12558, %from_elements_4418 : (tensor, tensor<2xi64>) -> tensor + %12594 = stablehlo.concatenate %12591, %12593, dim = 1 : (tensor, tensor) -> tensor + %12595 = "stablehlo.gather"(%12463, %12594) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %12596 = shape.shape_of %12589 : tensor -> tensor<2xindex> + %12597 = shape.shape_of %12595 : tensor -> tensor<2xindex> + %12598 = shape.cstr_broadcastable %12596, %12597 : tensor<2xindex>, tensor<2xindex> + %12599 = shape.assuming %12598 -> (tensor) { + %19688 = shape.broadcast %12596, %12597 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %12589, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %12595, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %12600 = shape.shape_of %12599 : tensor -> 
tensor<2xindex> + %12601 = stablehlo.dynamic_broadcast_in_dim %12599, %12600, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %12602 = stablehlo.dynamic_broadcast_in_dim %213, %12600, dims = [] : (tensor, tensor<2xindex>) -> tensor + %12603 = stablehlo.multiply %12601, %12602 : tensor + %dim_4419 = tensor.dim %12563, %c0 : tensor + %12604 = arith.index_cast %dim_4419 : index to i64 + %dim_4420 = tensor.dim %12599, %c0 : tensor + %12605 = arith.index_cast %dim_4420 : index to i64 + %12606 = arith.maxsi %12604, %12605 : i64 + %12607 = arith.index_cast %12606 : i64 to index + %from_elements_4421 = tensor.from_elements %12607, %c4096 : tensor<2xindex> + %12608 = stablehlo.dynamic_broadcast_in_dim %12563, %from_elements_4421, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_4422 = tensor.dim %12608, %c0 : tensor + %12609 = arith.index_cast %dim_4422 : index to i64 + %from_elements_4423 = tensor.from_elements %12609, %c4096_i64 : tensor<2xi64> + %12610 = stablehlo.real_dynamic_slice %12603, %c_22, %from_elements_4423, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_4424 = tensor.from_elements %12609, %c4096_i64, %c1_i64 : tensor<3xi64> + %12611 = stablehlo.dynamic_reshape %12608, %from_elements_4424 : (tensor, tensor<3xi64>) -> tensor + %12612 = stablehlo.dynamic_iota %from_elements_4424, dim = 1 : (tensor<3xi64>) -> tensor + %12613 = stablehlo.concatenate %12611, %12612, dim = 2 : (tensor, tensor) -> tensor + %12614 = "stablehlo.scatter"(%12551, %12613, %12610) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %12615 = stablehlo.slice %12423 [3:4, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %12616 = stablehlo.reshape %12615 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %12617 = stablehlo.custom_call @byteir.non_zero(%12616) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_4425 = tensor.dim %12617, %c0 : tensor + %12618 = arith.index_cast %dim_4425 : index to i64 + %from_elements_4426 = tensor.from_elements %12618, %c1_i64 : tensor<2xi64> + %12619 = stablehlo.real_dynamic_slice %12617, %c_22, %from_elements_4426, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4427 = tensor.dim %12619, %c0 : tensor + %12620 = arith.index_cast %dim_4427 : index to i64 + %from_elements_4428 = tensor.from_elements %12620 : tensor<1xi64> + %12621 = stablehlo.dynamic_reshape %12619, %from_elements_4428 : (tensor, tensor<1xi64>) -> tensor + %from_elements_4429 = tensor.from_elements %12618, %c2_i64 : tensor<2xi64> + %12622 = stablehlo.real_dynamic_slice %12617, %c_24, %from_elements_4429, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4430 = tensor.dim %12622, %c0 : tensor + %12623 = arith.index_cast %dim_4430 : index to i64 + %from_elements_4431 = tensor.from_elements %12623 : tensor<1xi64> + %12624 = stablehlo.dynamic_reshape %12622, %from_elements_4431 : (tensor, tensor<1xi64>) -> tensor + %dim_4432 = tensor.dim %12624, %c0 : tensor + %12625 = arith.index_cast %dim_4432 : index to i64 + %from_elements_4433 = tensor.from_elements %12625, %c1_i64 : tensor<2xi64> + %12626 = stablehlo.dynamic_reshape %12624, %from_elements_4433 : (tensor, tensor<2xi64>) -> tensor + %dim_4434 = tensor.dim %12626, %c0 : tensor + %12627 = arith.index_cast 
%dim_4434 : index to i64 + %from_elements_4435 = tensor.from_elements %c1_i64, %12627, %c4096_i64 : tensor<3xi64> + %12628 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_4435, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4436 = tensor.dim %12628, %c1 : tensor<1x?x4096xi64> + %12629 = arith.index_cast %dim_4436 : index to i64 + %from_elements_4437 = tensor.from_elements %c1_i64, %12629, %c4096_i64, %c1_i64 : tensor<4xi64> + %12630 = stablehlo.dynamic_reshape %12628, %from_elements_4437 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %12631 = stablehlo.dynamic_broadcast_in_dim %12626, %from_elements_4435, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4438 = tensor.dim %12631, %c1 : tensor<1x?x4096xi64> + %12632 = arith.index_cast %dim_4438 : index to i64 + %from_elements_4439 = tensor.from_elements %c1_i64, %12632, %c4096_i64, %c1_i64 : tensor<4xi64> + %12633 = stablehlo.dynamic_reshape %12631, %from_elements_4439 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %12634 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_4435, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4440 = tensor.dim %12634, %c1 : tensor<1x?x4096xi64> + %12635 = arith.index_cast %dim_4440 : index to i64 + %from_elements_4441 = tensor.from_elements %c1_i64, %12635, %c4096_i64, %c1_i64 : tensor<4xi64> + %12636 = stablehlo.dynamic_reshape %12634, %from_elements_4441 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %12637 = stablehlo.concatenate %12630, %12633, %12636, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %12638 = "stablehlo.gather"(%12434, %12637) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %12639 = shape.shape_of %12638 : tensor<1x?x4096xf32> -> tensor<3xindex> + %12640 = shape.num_elements %12639 : tensor<3xindex> -> index + %12641 = stablehlo.compute_reshape_shape %12640, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %12642 = stablehlo.dynamic_reshape %12638, %12641 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %12643 = stablehlo.dot %12642, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %12644 = stablehlo.logistic %12643 : tensor + %12645 = shape.shape_of %12644 : tensor -> tensor<2xindex> + %12646 = shape.shape_of %12643 : tensor -> tensor<2xindex> + %12647 = shape.cstr_broadcastable %12645, %12646 : tensor<2xindex>, tensor<2xindex> + %12648 = shape.assuming %12647 -> (tensor) { + %19688 = shape.broadcast %12645, %12646 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %12644, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %12643, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %12649 = shape.shape_of %12648 : tensor -> tensor<2xindex> + %12650 = shape.cstr_broadcastable %12649, %12646 : tensor<2xindex>, tensor<2xindex> + %12651 = shape.assuming %12650 -> (tensor) { + %19688 = shape.broadcast %12649, %12646 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %12648, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %12643, 
%19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %12652 = stablehlo.dot %12651, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_4442 = tensor.dim %12624, %c0 : tensor + %12653 = arith.index_cast %dim_4442 : index to i64 + %from_elements_4443 = tensor.from_elements %12653, %c1_i64 : tensor<2xi64> + %12654 = stablehlo.dynamic_reshape %12624, %from_elements_4443 : (tensor, tensor<2xi64>) -> tensor + %dim_4444 = tensor.dim %12621, %c0 : tensor + %12655 = arith.index_cast %dim_4444 : index to i64 + %from_elements_4445 = tensor.from_elements %12655, %c1_i64 : tensor<2xi64> + %12656 = stablehlo.dynamic_reshape %12621, %from_elements_4445 : (tensor, tensor<2xi64>) -> tensor + %12657 = stablehlo.concatenate %12654, %12656, dim = 1 : (tensor, tensor) -> tensor + %12658 = "stablehlo.gather"(%12463, %12657) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %12659 = shape.shape_of %12652 : tensor -> tensor<2xindex> + %12660 = shape.shape_of %12658 : tensor -> tensor<2xindex> + %12661 = shape.cstr_broadcastable %12659, %12660 : tensor<2xindex>, tensor<2xindex> + %12662 = shape.assuming %12661 -> (tensor) { + %19688 = shape.broadcast %12659, %12660 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %12652, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %12658, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %12663 = shape.shape_of %12662 : tensor -> tensor<2xindex> + %12664 = stablehlo.dynamic_broadcast_in_dim %12662, %12663, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %12665 = stablehlo.dynamic_broadcast_in_dim %213, %12663, dims = [] : (tensor, tensor<2xindex>) -> tensor + %12666 = stablehlo.multiply %12664, %12665 : tensor + %dim_4446 = tensor.dim %12626, %c0 : tensor + %12667 = arith.index_cast %dim_4446 : index to i64 + %dim_4447 = tensor.dim %12662, %c0 : tensor + %12668 = arith.index_cast %dim_4447 : index to i64 + %12669 = arith.maxsi %12667, %12668 : i64 + %12670 = arith.index_cast %12669 : i64 to index + %from_elements_4448 = tensor.from_elements %12670, %c4096 : tensor<2xindex> + %12671 = stablehlo.dynamic_broadcast_in_dim %12626, %from_elements_4448, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_4449 = tensor.dim %12671, %c0 : tensor + %12672 = arith.index_cast %dim_4449 : index to i64 + %from_elements_4450 = tensor.from_elements %12672, %c4096_i64 : tensor<2xi64> + %12673 = stablehlo.real_dynamic_slice %12666, %c_22, %from_elements_4450, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_4451 = tensor.from_elements %12672, %c4096_i64, %c1_i64 : tensor<3xi64> + %12674 = stablehlo.dynamic_reshape %12671, %from_elements_4451 : (tensor, tensor<3xi64>) -> tensor + %12675 = stablehlo.dynamic_iota %from_elements_4451, dim = 1 : (tensor<3xi64>) -> tensor + %12676 = stablehlo.concatenate %12674, %12675, dim = 2 : (tensor, tensor) -> tensor + %12677 = "stablehlo.scatter"(%12614, %12676, %12673) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) 
: (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %12678 = stablehlo.slice %12423 [4:5, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %12679 = stablehlo.reshape %12678 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %12680 = stablehlo.custom_call @byteir.non_zero(%12679) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_4452 = tensor.dim %12680, %c0 : tensor + %12681 = arith.index_cast %dim_4452 : index to i64 + %from_elements_4453 = tensor.from_elements %12681, %c1_i64 : tensor<2xi64> + %12682 = stablehlo.real_dynamic_slice %12680, %c_22, %from_elements_4453, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4454 = tensor.dim %12682, %c0 : tensor + %12683 = arith.index_cast %dim_4454 : index to i64 + %from_elements_4455 = tensor.from_elements %12683 : tensor<1xi64> + %12684 = stablehlo.dynamic_reshape %12682, %from_elements_4455 : (tensor, tensor<1xi64>) -> tensor + %from_elements_4456 = tensor.from_elements %12681, %c2_i64 : tensor<2xi64> + %12685 = stablehlo.real_dynamic_slice %12680, %c_24, %from_elements_4456, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4457 = tensor.dim %12685, %c0 : tensor + %12686 = arith.index_cast %dim_4457 : index to i64 + %from_elements_4458 = tensor.from_elements %12686 : tensor<1xi64> + %12687 = stablehlo.dynamic_reshape %12685, %from_elements_4458 : (tensor, tensor<1xi64>) -> tensor + %dim_4459 = tensor.dim %12687, %c0 : tensor + %12688 = arith.index_cast %dim_4459 : index to i64 + %from_elements_4460 = tensor.from_elements %12688, %c1_i64 : tensor<2xi64> + %12689 = stablehlo.dynamic_reshape %12687, %from_elements_4460 : (tensor, tensor<2xi64>) -> tensor + %dim_4461 = tensor.dim %12689, %c0 : tensor + %12690 = arith.index_cast %dim_4461 : index to i64 + %from_elements_4462 = tensor.from_elements %c1_i64, %12690, %c4096_i64 : tensor<3xi64> + %12691 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_4462, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4463 = tensor.dim %12691, %c1 : tensor<1x?x4096xi64> + %12692 = arith.index_cast %dim_4463 : index to i64 + %from_elements_4464 = tensor.from_elements %c1_i64, %12692, %c4096_i64, %c1_i64 : tensor<4xi64> + %12693 = stablehlo.dynamic_reshape %12691, %from_elements_4464 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %12694 = stablehlo.dynamic_broadcast_in_dim %12689, %from_elements_4462, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4465 = tensor.dim %12694, %c1 : tensor<1x?x4096xi64> + %12695 = arith.index_cast %dim_4465 : index to i64 + %from_elements_4466 = tensor.from_elements %c1_i64, %12695, %c4096_i64, %c1_i64 : tensor<4xi64> + %12696 = stablehlo.dynamic_reshape %12694, %from_elements_4466 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %12697 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_4462, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4467 = tensor.dim %12697, %c1 : tensor<1x?x4096xi64> + %12698 = arith.index_cast %dim_4467 : index to i64 + %from_elements_4468 = tensor.from_elements %c1_i64, %12698, %c4096_i64, %c1_i64 : tensor<4xi64> + %12699 = stablehlo.dynamic_reshape %12697, %from_elements_4468 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %12700 = stablehlo.concatenate %12693, %12696, %12699, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %12701 = "stablehlo.gather"(%12434, 
%12700) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %12702 = shape.shape_of %12701 : tensor<1x?x4096xf32> -> tensor<3xindex> + %12703 = shape.num_elements %12702 : tensor<3xindex> -> index + %12704 = stablehlo.compute_reshape_shape %12703, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %12705 = stablehlo.dynamic_reshape %12701, %12704 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %12706 = stablehlo.dot %12705, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %12707 = stablehlo.logistic %12706 : tensor + %12708 = shape.shape_of %12707 : tensor -> tensor<2xindex> + %12709 = shape.shape_of %12706 : tensor -> tensor<2xindex> + %12710 = shape.cstr_broadcastable %12708, %12709 : tensor<2xindex>, tensor<2xindex> + %12711 = shape.assuming %12710 -> (tensor) { + %19688 = shape.broadcast %12708, %12709 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %12707, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %12706, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %12712 = shape.shape_of %12711 : tensor -> tensor<2xindex> + %12713 = shape.cstr_broadcastable %12712, %12709 : tensor<2xindex>, tensor<2xindex> + %12714 = shape.assuming %12713 -> (tensor) { + %19688 = shape.broadcast %12712, %12709 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %12711, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %12706, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %12715 = stablehlo.dot %12714, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_4469 = tensor.dim %12687, %c0 : tensor + %12716 = arith.index_cast %dim_4469 : index to i64 + %from_elements_4470 = tensor.from_elements %12716, %c1_i64 : tensor<2xi64> + %12717 = stablehlo.dynamic_reshape %12687, %from_elements_4470 : (tensor, tensor<2xi64>) -> tensor + %dim_4471 = tensor.dim %12684, %c0 : tensor + %12718 = arith.index_cast %dim_4471 : index to i64 + %from_elements_4472 = tensor.from_elements %12718, %c1_i64 : tensor<2xi64> + %12719 = stablehlo.dynamic_reshape %12684, %from_elements_4472 : (tensor, tensor<2xi64>) -> tensor + %12720 = stablehlo.concatenate %12717, %12719, dim = 1 : (tensor, tensor) -> tensor + %12721 = "stablehlo.gather"(%12463, %12720) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %12722 = shape.shape_of %12715 : tensor -> tensor<2xindex> + %12723 = shape.shape_of %12721 : tensor -> tensor<2xindex> + %12724 = shape.cstr_broadcastable %12722, %12723 : tensor<2xindex>, tensor<2xindex> + %12725 = shape.assuming %12724 -> (tensor) { + %19688 = shape.broadcast %12722, %12723 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %12715, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %12721, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %12726 = shape.shape_of %12725 : tensor -> 
tensor<2xindex> + %12727 = stablehlo.dynamic_broadcast_in_dim %12725, %12726, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %12728 = stablehlo.dynamic_broadcast_in_dim %213, %12726, dims = [] : (tensor, tensor<2xindex>) -> tensor + %12729 = stablehlo.multiply %12727, %12728 : tensor + %dim_4473 = tensor.dim %12689, %c0 : tensor + %12730 = arith.index_cast %dim_4473 : index to i64 + %dim_4474 = tensor.dim %12725, %c0 : tensor + %12731 = arith.index_cast %dim_4474 : index to i64 + %12732 = arith.maxsi %12730, %12731 : i64 + %12733 = arith.index_cast %12732 : i64 to index + %from_elements_4475 = tensor.from_elements %12733, %c4096 : tensor<2xindex> + %12734 = stablehlo.dynamic_broadcast_in_dim %12689, %from_elements_4475, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_4476 = tensor.dim %12734, %c0 : tensor + %12735 = arith.index_cast %dim_4476 : index to i64 + %from_elements_4477 = tensor.from_elements %12735, %c4096_i64 : tensor<2xi64> + %12736 = stablehlo.real_dynamic_slice %12729, %c_22, %from_elements_4477, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_4478 = tensor.from_elements %12735, %c4096_i64, %c1_i64 : tensor<3xi64> + %12737 = stablehlo.dynamic_reshape %12734, %from_elements_4478 : (tensor, tensor<3xi64>) -> tensor + %12738 = stablehlo.dynamic_iota %from_elements_4478, dim = 1 : (tensor<3xi64>) -> tensor + %12739 = stablehlo.concatenate %12737, %12738, dim = 2 : (tensor, tensor) -> tensor + %12740 = "stablehlo.scatter"(%12677, %12739, %12736) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %12741 = stablehlo.slice %12423 [5:6, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %12742 = stablehlo.reshape %12741 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %12743 = stablehlo.custom_call @byteir.non_zero(%12742) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_4479 = tensor.dim %12743, %c0 : tensor + %12744 = arith.index_cast %dim_4479 : index to i64 + %from_elements_4480 = tensor.from_elements %12744, %c1_i64 : tensor<2xi64> + %12745 = stablehlo.real_dynamic_slice %12743, %c_22, %from_elements_4480, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4481 = tensor.dim %12745, %c0 : tensor + %12746 = arith.index_cast %dim_4481 : index to i64 + %from_elements_4482 = tensor.from_elements %12746 : tensor<1xi64> + %12747 = stablehlo.dynamic_reshape %12745, %from_elements_4482 : (tensor, tensor<1xi64>) -> tensor + %from_elements_4483 = tensor.from_elements %12744, %c2_i64 : tensor<2xi64> + %12748 = stablehlo.real_dynamic_slice %12743, %c_24, %from_elements_4483, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4484 = tensor.dim %12748, %c0 : tensor + %12749 = arith.index_cast %dim_4484 : index to i64 + %from_elements_4485 = tensor.from_elements %12749 : tensor<1xi64> + %12750 = stablehlo.dynamic_reshape %12748, %from_elements_4485 : (tensor, tensor<1xi64>) -> tensor + %dim_4486 = tensor.dim %12750, %c0 : tensor + %12751 = arith.index_cast %dim_4486 : index to i64 + %from_elements_4487 = tensor.from_elements %12751, %c1_i64 : tensor<2xi64> + %12752 = stablehlo.dynamic_reshape %12750, %from_elements_4487 : (tensor, tensor<2xi64>) -> tensor + %dim_4488 = tensor.dim %12752, %c0 : tensor + %12753 = arith.index_cast 
%dim_4488 : index to i64 + %from_elements_4489 = tensor.from_elements %c1_i64, %12753, %c4096_i64 : tensor<3xi64> + %12754 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_4489, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4490 = tensor.dim %12754, %c1 : tensor<1x?x4096xi64> + %12755 = arith.index_cast %dim_4490 : index to i64 + %from_elements_4491 = tensor.from_elements %c1_i64, %12755, %c4096_i64, %c1_i64 : tensor<4xi64> + %12756 = stablehlo.dynamic_reshape %12754, %from_elements_4491 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %12757 = stablehlo.dynamic_broadcast_in_dim %12752, %from_elements_4489, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4492 = tensor.dim %12757, %c1 : tensor<1x?x4096xi64> + %12758 = arith.index_cast %dim_4492 : index to i64 + %from_elements_4493 = tensor.from_elements %c1_i64, %12758, %c4096_i64, %c1_i64 : tensor<4xi64> + %12759 = stablehlo.dynamic_reshape %12757, %from_elements_4493 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %12760 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_4489, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4494 = tensor.dim %12760, %c1 : tensor<1x?x4096xi64> + %12761 = arith.index_cast %dim_4494 : index to i64 + %from_elements_4495 = tensor.from_elements %c1_i64, %12761, %c4096_i64, %c1_i64 : tensor<4xi64> + %12762 = stablehlo.dynamic_reshape %12760, %from_elements_4495 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %12763 = stablehlo.concatenate %12756, %12759, %12762, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %12764 = "stablehlo.gather"(%12434, %12763) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %12765 = shape.shape_of %12764 : tensor<1x?x4096xf32> -> tensor<3xindex> + %12766 = shape.num_elements %12765 : tensor<3xindex> -> index + %12767 = stablehlo.compute_reshape_shape %12766, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %12768 = stablehlo.dynamic_reshape %12764, %12767 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %12769 = stablehlo.dot %12768, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %12770 = stablehlo.logistic %12769 : tensor + %12771 = shape.shape_of %12770 : tensor -> tensor<2xindex> + %12772 = shape.shape_of %12769 : tensor -> tensor<2xindex> + %12773 = shape.cstr_broadcastable %12771, %12772 : tensor<2xindex>, tensor<2xindex> + %12774 = shape.assuming %12773 -> (tensor) { + %19688 = shape.broadcast %12771, %12772 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %12770, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %12769, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %12775 = shape.shape_of %12774 : tensor -> tensor<2xindex> + %12776 = shape.cstr_broadcastable %12775, %12772 : tensor<2xindex>, tensor<2xindex> + %12777 = shape.assuming %12776 -> (tensor) { + %19688 = shape.broadcast %12775, %12772 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %12774, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %12769, 
%19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %12778 = stablehlo.dot %12777, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_4496 = tensor.dim %12750, %c0 : tensor + %12779 = arith.index_cast %dim_4496 : index to i64 + %from_elements_4497 = tensor.from_elements %12779, %c1_i64 : tensor<2xi64> + %12780 = stablehlo.dynamic_reshape %12750, %from_elements_4497 : (tensor, tensor<2xi64>) -> tensor + %dim_4498 = tensor.dim %12747, %c0 : tensor + %12781 = arith.index_cast %dim_4498 : index to i64 + %from_elements_4499 = tensor.from_elements %12781, %c1_i64 : tensor<2xi64> + %12782 = stablehlo.dynamic_reshape %12747, %from_elements_4499 : (tensor, tensor<2xi64>) -> tensor + %12783 = stablehlo.concatenate %12780, %12782, dim = 1 : (tensor, tensor) -> tensor + %12784 = "stablehlo.gather"(%12463, %12783) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %12785 = shape.shape_of %12778 : tensor -> tensor<2xindex> + %12786 = shape.shape_of %12784 : tensor -> tensor<2xindex> + %12787 = shape.cstr_broadcastable %12785, %12786 : tensor<2xindex>, tensor<2xindex> + %12788 = shape.assuming %12787 -> (tensor) { + %19688 = shape.broadcast %12785, %12786 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %12778, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %12784, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %12789 = shape.shape_of %12788 : tensor -> tensor<2xindex> + %12790 = stablehlo.dynamic_broadcast_in_dim %12788, %12789, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %12791 = stablehlo.dynamic_broadcast_in_dim %213, %12789, dims = [] : (tensor, tensor<2xindex>) -> tensor + %12792 = stablehlo.multiply %12790, %12791 : tensor + %dim_4500 = tensor.dim %12752, %c0 : tensor + %12793 = arith.index_cast %dim_4500 : index to i64 + %dim_4501 = tensor.dim %12788, %c0 : tensor + %12794 = arith.index_cast %dim_4501 : index to i64 + %12795 = arith.maxsi %12793, %12794 : i64 + %12796 = arith.index_cast %12795 : i64 to index + %from_elements_4502 = tensor.from_elements %12796, %c4096 : tensor<2xindex> + %12797 = stablehlo.dynamic_broadcast_in_dim %12752, %from_elements_4502, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_4503 = tensor.dim %12797, %c0 : tensor + %12798 = arith.index_cast %dim_4503 : index to i64 + %from_elements_4504 = tensor.from_elements %12798, %c4096_i64 : tensor<2xi64> + %12799 = stablehlo.real_dynamic_slice %12792, %c_22, %from_elements_4504, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_4505 = tensor.from_elements %12798, %c4096_i64, %c1_i64 : tensor<3xi64> + %12800 = stablehlo.dynamic_reshape %12797, %from_elements_4505 : (tensor, tensor<3xi64>) -> tensor + %12801 = stablehlo.dynamic_iota %from_elements_4505, dim = 1 : (tensor<3xi64>) -> tensor + %12802 = stablehlo.concatenate %12800, %12801, dim = 2 : (tensor, tensor) -> tensor + %12803 = "stablehlo.scatter"(%12740, %12802, %12799) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) 
: (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %12804 = stablehlo.slice %12423 [6:7, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %12805 = stablehlo.reshape %12804 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %12806 = stablehlo.custom_call @byteir.non_zero(%12805) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_4506 = tensor.dim %12806, %c0 : tensor + %12807 = arith.index_cast %dim_4506 : index to i64 + %from_elements_4507 = tensor.from_elements %12807, %c1_i64 : tensor<2xi64> + %12808 = stablehlo.real_dynamic_slice %12806, %c_22, %from_elements_4507, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4508 = tensor.dim %12808, %c0 : tensor + %12809 = arith.index_cast %dim_4508 : index to i64 + %from_elements_4509 = tensor.from_elements %12809 : tensor<1xi64> + %12810 = stablehlo.dynamic_reshape %12808, %from_elements_4509 : (tensor, tensor<1xi64>) -> tensor + %from_elements_4510 = tensor.from_elements %12807, %c2_i64 : tensor<2xi64> + %12811 = stablehlo.real_dynamic_slice %12806, %c_24, %from_elements_4510, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4511 = tensor.dim %12811, %c0 : tensor + %12812 = arith.index_cast %dim_4511 : index to i64 + %from_elements_4512 = tensor.from_elements %12812 : tensor<1xi64> + %12813 = stablehlo.dynamic_reshape %12811, %from_elements_4512 : (tensor, tensor<1xi64>) -> tensor + %dim_4513 = tensor.dim %12813, %c0 : tensor + %12814 = arith.index_cast %dim_4513 : index to i64 + %from_elements_4514 = tensor.from_elements %12814, %c1_i64 : tensor<2xi64> + %12815 = stablehlo.dynamic_reshape %12813, %from_elements_4514 : (tensor, tensor<2xi64>) -> tensor + %dim_4515 = tensor.dim %12815, %c0 : tensor + %12816 = arith.index_cast %dim_4515 : index to i64 + %from_elements_4516 = tensor.from_elements %c1_i64, %12816, %c4096_i64 : tensor<3xi64> + %12817 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_4516, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4517 = tensor.dim %12817, %c1 : tensor<1x?x4096xi64> + %12818 = arith.index_cast %dim_4517 : index to i64 + %from_elements_4518 = tensor.from_elements %c1_i64, %12818, %c4096_i64, %c1_i64 : tensor<4xi64> + %12819 = stablehlo.dynamic_reshape %12817, %from_elements_4518 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %12820 = stablehlo.dynamic_broadcast_in_dim %12815, %from_elements_4516, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4519 = tensor.dim %12820, %c1 : tensor<1x?x4096xi64> + %12821 = arith.index_cast %dim_4519 : index to i64 + %from_elements_4520 = tensor.from_elements %c1_i64, %12821, %c4096_i64, %c1_i64 : tensor<4xi64> + %12822 = stablehlo.dynamic_reshape %12820, %from_elements_4520 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %12823 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_4516, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4521 = tensor.dim %12823, %c1 : tensor<1x?x4096xi64> + %12824 = arith.index_cast %dim_4521 : index to i64 + %from_elements_4522 = tensor.from_elements %c1_i64, %12824, %c4096_i64, %c1_i64 : tensor<4xi64> + %12825 = stablehlo.dynamic_reshape %12823, %from_elements_4522 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %12826 = stablehlo.concatenate %12819, %12822, %12825, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %12827 = "stablehlo.gather"(%12434, 
%12826) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %12828 = shape.shape_of %12827 : tensor<1x?x4096xf32> -> tensor<3xindex> + %12829 = shape.num_elements %12828 : tensor<3xindex> -> index + %12830 = stablehlo.compute_reshape_shape %12829, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %12831 = stablehlo.dynamic_reshape %12827, %12830 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %12832 = stablehlo.dot %12831, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %12833 = stablehlo.logistic %12832 : tensor + %12834 = shape.shape_of %12833 : tensor -> tensor<2xindex> + %12835 = shape.shape_of %12832 : tensor -> tensor<2xindex> + %12836 = shape.cstr_broadcastable %12834, %12835 : tensor<2xindex>, tensor<2xindex> + %12837 = shape.assuming %12836 -> (tensor) { + %19688 = shape.broadcast %12834, %12835 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %12833, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %12832, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %12838 = shape.shape_of %12837 : tensor -> tensor<2xindex> + %12839 = shape.cstr_broadcastable %12838, %12835 : tensor<2xindex>, tensor<2xindex> + %12840 = shape.assuming %12839 -> (tensor) { + %19688 = shape.broadcast %12838, %12835 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %12837, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %12832, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %12841 = stablehlo.dot %12840, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_4523 = tensor.dim %12813, %c0 : tensor + %12842 = arith.index_cast %dim_4523 : index to i64 + %from_elements_4524 = tensor.from_elements %12842, %c1_i64 : tensor<2xi64> + %12843 = stablehlo.dynamic_reshape %12813, %from_elements_4524 : (tensor, tensor<2xi64>) -> tensor + %dim_4525 = tensor.dim %12810, %c0 : tensor + %12844 = arith.index_cast %dim_4525 : index to i64 + %from_elements_4526 = tensor.from_elements %12844, %c1_i64 : tensor<2xi64> + %12845 = stablehlo.dynamic_reshape %12810, %from_elements_4526 : (tensor, tensor<2xi64>) -> tensor + %12846 = stablehlo.concatenate %12843, %12845, dim = 1 : (tensor, tensor) -> tensor + %12847 = "stablehlo.gather"(%12463, %12846) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %12848 = shape.shape_of %12841 : tensor -> tensor<2xindex> + %12849 = shape.shape_of %12847 : tensor -> tensor<2xindex> + %12850 = shape.cstr_broadcastable %12848, %12849 : tensor<2xindex>, tensor<2xindex> + %12851 = shape.assuming %12850 -> (tensor) { + %19688 = shape.broadcast %12848, %12849 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %12841, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %12847, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %12852 = shape.shape_of %12851 : tensor -> 
tensor<2xindex> + %12853 = stablehlo.dynamic_broadcast_in_dim %12851, %12852, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %12854 = stablehlo.dynamic_broadcast_in_dim %213, %12852, dims = [] : (tensor, tensor<2xindex>) -> tensor + %12855 = stablehlo.multiply %12853, %12854 : tensor + %dim_4527 = tensor.dim %12815, %c0 : tensor + %12856 = arith.index_cast %dim_4527 : index to i64 + %dim_4528 = tensor.dim %12851, %c0 : tensor + %12857 = arith.index_cast %dim_4528 : index to i64 + %12858 = arith.maxsi %12856, %12857 : i64 + %12859 = arith.index_cast %12858 : i64 to index + %from_elements_4529 = tensor.from_elements %12859, %c4096 : tensor<2xindex> + %12860 = stablehlo.dynamic_broadcast_in_dim %12815, %from_elements_4529, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_4530 = tensor.dim %12860, %c0 : tensor + %12861 = arith.index_cast %dim_4530 : index to i64 + %from_elements_4531 = tensor.from_elements %12861, %c4096_i64 : tensor<2xi64> + %12862 = stablehlo.real_dynamic_slice %12855, %c_22, %from_elements_4531, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_4532 = tensor.from_elements %12861, %c4096_i64, %c1_i64 : tensor<3xi64> + %12863 = stablehlo.dynamic_reshape %12860, %from_elements_4532 : (tensor, tensor<3xi64>) -> tensor + %12864 = stablehlo.dynamic_iota %from_elements_4532, dim = 1 : (tensor<3xi64>) -> tensor + %12865 = stablehlo.concatenate %12863, %12864, dim = 2 : (tensor, tensor) -> tensor + %12866 = "stablehlo.scatter"(%12803, %12865, %12862) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %12867 = stablehlo.slice %12423 [7:8, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %12868 = stablehlo.reshape %12867 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %12869 = stablehlo.custom_call @byteir.non_zero(%12868) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_4533 = tensor.dim %12869, %c0 : tensor + %12870 = arith.index_cast %dim_4533 : index to i64 + %from_elements_4534 = tensor.from_elements %12870, %c1_i64 : tensor<2xi64> + %12871 = stablehlo.real_dynamic_slice %12869, %c_22, %from_elements_4534, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4535 = tensor.dim %12871, %c0 : tensor + %12872 = arith.index_cast %dim_4535 : index to i64 + %from_elements_4536 = tensor.from_elements %12872 : tensor<1xi64> + %12873 = stablehlo.dynamic_reshape %12871, %from_elements_4536 : (tensor, tensor<1xi64>) -> tensor + %from_elements_4537 = tensor.from_elements %12870, %c2_i64 : tensor<2xi64> + %12874 = stablehlo.real_dynamic_slice %12869, %c_24, %from_elements_4537, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4538 = tensor.dim %12874, %c0 : tensor + %12875 = arith.index_cast %dim_4538 : index to i64 + %from_elements_4539 = tensor.from_elements %12875 : tensor<1xi64> + %12876 = stablehlo.dynamic_reshape %12874, %from_elements_4539 : (tensor, tensor<1xi64>) -> tensor + %dim_4540 = tensor.dim %12876, %c0 : tensor + %12877 = arith.index_cast %dim_4540 : index to i64 + %from_elements_4541 = tensor.from_elements %12877, %c1_i64 : tensor<2xi64> + %12878 = stablehlo.dynamic_reshape %12876, %from_elements_4541 : (tensor, tensor<2xi64>) -> tensor + %dim_4542 = tensor.dim %12878, %c0 : tensor + %12879 = arith.index_cast 
%dim_4542 : index to i64 + %from_elements_4543 = tensor.from_elements %c1_i64, %12879, %c4096_i64 : tensor<3xi64> + %12880 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_4543, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4544 = tensor.dim %12880, %c1 : tensor<1x?x4096xi64> + %12881 = arith.index_cast %dim_4544 : index to i64 + %from_elements_4545 = tensor.from_elements %c1_i64, %12881, %c4096_i64, %c1_i64 : tensor<4xi64> + %12882 = stablehlo.dynamic_reshape %12880, %from_elements_4545 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %12883 = stablehlo.dynamic_broadcast_in_dim %12878, %from_elements_4543, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4546 = tensor.dim %12883, %c1 : tensor<1x?x4096xi64> + %12884 = arith.index_cast %dim_4546 : index to i64 + %from_elements_4547 = tensor.from_elements %c1_i64, %12884, %c4096_i64, %c1_i64 : tensor<4xi64> + %12885 = stablehlo.dynamic_reshape %12883, %from_elements_4547 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %12886 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_4543, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4548 = tensor.dim %12886, %c1 : tensor<1x?x4096xi64> + %12887 = arith.index_cast %dim_4548 : index to i64 + %from_elements_4549 = tensor.from_elements %c1_i64, %12887, %c4096_i64, %c1_i64 : tensor<4xi64> + %12888 = stablehlo.dynamic_reshape %12886, %from_elements_4549 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %12889 = stablehlo.concatenate %12882, %12885, %12888, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %12890 = "stablehlo.gather"(%12434, %12889) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %12891 = shape.shape_of %12890 : tensor<1x?x4096xf32> -> tensor<3xindex> + %12892 = shape.num_elements %12891 : tensor<3xindex> -> index + %12893 = stablehlo.compute_reshape_shape %12892, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %12894 = stablehlo.dynamic_reshape %12890, %12893 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %12895 = stablehlo.dot %12894, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %12896 = stablehlo.logistic %12895 : tensor + %12897 = shape.shape_of %12896 : tensor -> tensor<2xindex> + %12898 = shape.shape_of %12895 : tensor -> tensor<2xindex> + %12899 = shape.cstr_broadcastable %12897, %12898 : tensor<2xindex>, tensor<2xindex> + %12900 = shape.assuming %12899 -> (tensor) { + %19688 = shape.broadcast %12897, %12898 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %12896, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %12895, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %12901 = shape.shape_of %12900 : tensor -> tensor<2xindex> + %12902 = shape.cstr_broadcastable %12901, %12898 : tensor<2xindex>, tensor<2xindex> + %12903 = shape.assuming %12902 -> (tensor) { + %19688 = shape.broadcast %12901, %12898 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %12900, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %12895, 
%19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %12904 = stablehlo.dot %12903, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_4550 = tensor.dim %12876, %c0 : tensor + %12905 = arith.index_cast %dim_4550 : index to i64 + %from_elements_4551 = tensor.from_elements %12905, %c1_i64 : tensor<2xi64> + %12906 = stablehlo.dynamic_reshape %12876, %from_elements_4551 : (tensor, tensor<2xi64>) -> tensor + %dim_4552 = tensor.dim %12873, %c0 : tensor + %12907 = arith.index_cast %dim_4552 : index to i64 + %from_elements_4553 = tensor.from_elements %12907, %c1_i64 : tensor<2xi64> + %12908 = stablehlo.dynamic_reshape %12873, %from_elements_4553 : (tensor, tensor<2xi64>) -> tensor + %12909 = stablehlo.concatenate %12906, %12908, dim = 1 : (tensor, tensor) -> tensor + %12910 = "stablehlo.gather"(%12463, %12909) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %12911 = shape.shape_of %12904 : tensor -> tensor<2xindex> + %12912 = shape.shape_of %12910 : tensor -> tensor<2xindex> + %12913 = shape.cstr_broadcastable %12911, %12912 : tensor<2xindex>, tensor<2xindex> + %12914 = shape.assuming %12913 -> (tensor) { + %19688 = shape.broadcast %12911, %12912 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %12904, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %12910, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %12915 = shape.shape_of %12914 : tensor -> tensor<2xindex> + %12916 = stablehlo.dynamic_broadcast_in_dim %12914, %12915, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %12917 = stablehlo.dynamic_broadcast_in_dim %213, %12915, dims = [] : (tensor, tensor<2xindex>) -> tensor + %12918 = stablehlo.multiply %12916, %12917 : tensor + %dim_4554 = tensor.dim %12878, %c0 : tensor + %12919 = arith.index_cast %dim_4554 : index to i64 + %dim_4555 = tensor.dim %12914, %c0 : tensor + %12920 = arith.index_cast %dim_4555 : index to i64 + %12921 = arith.maxsi %12919, %12920 : i64 + %12922 = arith.index_cast %12921 : i64 to index + %from_elements_4556 = tensor.from_elements %12922, %c4096 : tensor<2xindex> + %12923 = stablehlo.dynamic_broadcast_in_dim %12878, %from_elements_4556, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_4557 = tensor.dim %12923, %c0 : tensor + %12924 = arith.index_cast %dim_4557 : index to i64 + %from_elements_4558 = tensor.from_elements %12924, %c4096_i64 : tensor<2xi64> + %12925 = stablehlo.real_dynamic_slice %12918, %c_22, %from_elements_4558, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_4559 = tensor.from_elements %12924, %c4096_i64, %c1_i64 : tensor<3xi64> + %12926 = stablehlo.dynamic_reshape %12923, %from_elements_4559 : (tensor, tensor<3xi64>) -> tensor + %12927 = stablehlo.dynamic_iota %from_elements_4559, dim = 1 : (tensor<3xi64>) -> tensor + %12928 = stablehlo.concatenate %12926, %12927, dim = 2 : (tensor, tensor) -> tensor + %12929 = "stablehlo.scatter"(%12866, %12928, %12925) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) 
: (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %12930 = stablehlo.reshape %12929 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %12931 = stablehlo.add %12396, %12930 : tensor<3x1x4096xf32> + %12932 = stablehlo.broadcast_in_dim %12931, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %12933 = stablehlo.power %12932, %15 : tensor<3x1x4096xf32> + %12934 = stablehlo.reduce(%12933 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %12935 = stablehlo.reshape %12934 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %12936 = stablehlo.broadcast_in_dim %12935, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %12937 = stablehlo.divide %12936, %21 : tensor<3x1x1xf32> + %12938 = stablehlo.broadcast_in_dim %12937, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %12939 = stablehlo.add %12938, %25 : tensor<3x1x1xf32> + %12940 = stablehlo.rsqrt %12939 : tensor<3x1x1xf32> + %12941 = stablehlo.broadcast_in_dim %12940, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %12942 = stablehlo.multiply %12932, %12941 : tensor<3x1x4096xf32> + %12943 = stablehlo.broadcast_in_dim %12942, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %12944 = stablehlo.multiply %12943, %31 : tensor<3x1x4096xf32> + %12945 = stablehlo.reshape %12944 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %12946 = stablehlo.dot %12945, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %12947 = stablehlo.reshape %12946 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %12948 = stablehlo.dot %12945, %37 : (tensor<3x4096xf32>, tensor<4096x1024xf32>) -> tensor<3x1024xf32> + %12949 = stablehlo.reshape %12948 : (tensor<3x1024xf32>) -> tensor<3x1x1024xf32> + %12950 = stablehlo.reshape %12947 : (tensor<3x1x4096xf32>) -> tensor<3x1x32x128xf32> + %12951 = stablehlo.transpose %12950, dims = [0, 2, 1, 3] : (tensor<3x1x32x128xf32>) -> tensor<3x32x1x128xf32> + %12952 = stablehlo.reshape %12949 : (tensor<3x1x1024xf32>) -> tensor<3x1x8x128xf32> + %12953 = stablehlo.transpose %12952, dims = [0, 2, 1, 3] : (tensor<3x1x8x128xf32>) -> tensor<3x8x1x128xf32> + %12954 = stablehlo.slice %arg42 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %12955 = stablehlo.slice %arg43 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %12956 = "stablehlo.gather"(%12954, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %12957 = stablehlo.reshape %12956 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %12958 = "stablehlo.gather"(%12955, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %12959 = stablehlo.reshape %12958 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %12960 = stablehlo.broadcast_in_dim %12951, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %12961 = stablehlo.broadcast_in_dim %12957, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %12962 = stablehlo.multiply %12960, %12961 : tensor<3x32x1x128xf32> + %12963 = stablehlo.slice %12951 [0:3, 0:32, 0:1, 0:64] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %12964 = stablehlo.slice %12951 [0:3, 0:32, 0:1, 64:128] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %12965 = stablehlo.negate %12964 : tensor<3x32x1x64xf32> + %12966 = 
stablehlo.concatenate %12965, %12963, dim = 3 : (tensor<3x32x1x64xf32>, tensor<3x32x1x64xf32>) -> tensor<3x32x1x128xf32> + %12967 = stablehlo.broadcast_in_dim %12966, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %12968 = stablehlo.broadcast_in_dim %12959, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %12969 = stablehlo.multiply %12967, %12968 : tensor<3x32x1x128xf32> + %12970 = stablehlo.add %12962, %12969 : tensor<3x32x1x128xf32> + %12971 = stablehlo.broadcast_in_dim %12953, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %12972 = stablehlo.broadcast_in_dim %12957, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %12973 = stablehlo.multiply %12971, %12972 : tensor<3x8x1x128xf32> + %12974 = stablehlo.slice %12953 [0:3, 0:8, 0:1, 0:64] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %12975 = stablehlo.slice %12953 [0:3, 0:8, 0:1, 64:128] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %12976 = stablehlo.negate %12975 : tensor<3x8x1x64xf32> + %12977 = stablehlo.concatenate %12976, %12974, dim = 3 : (tensor<3x8x1x64xf32>, tensor<3x8x1x64xf32>) -> tensor<3x8x1x128xf32> + %12978 = stablehlo.broadcast_in_dim %12977, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %12979 = stablehlo.broadcast_in_dim %12959, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %12980 = stablehlo.multiply %12978, %12979 : tensor<3x8x1x128xf32> + %12981 = stablehlo.add %12973, %12980 : tensor<3x8x1x128xf32> + %12982 = stablehlo.concatenate %arg107, %12981, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %12983 = stablehlo.concatenate %arg108, %12953, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %12984 = stablehlo.reshape %12982 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %12985 = stablehlo.broadcast_in_dim %12984, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %12986 = stablehlo.reshape %12985 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %12987 = stablehlo.reshape %12983 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %12988 = stablehlo.broadcast_in_dim %12987, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %12989 = stablehlo.reshape %12988 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %12990 = stablehlo.transpose %12986, dims = [0, 1, 3, 2] : (tensor<3x32x8x128xf32>) -> tensor<3x32x128x8xf32> + %12991 = stablehlo.reshape %12970 : (tensor<3x32x1x128xf32>) -> tensor<96x1x128xf32> + %12992 = stablehlo.reshape %12990 : (tensor<3x32x128x8xf32>) -> tensor<96x128x8xf32> + %12993 = stablehlo.broadcast_in_dim %12992, dims = [0, 1, 2] : (tensor<96x128x8xf32>) -> tensor<96x128x8xf32> + %12994 = stablehlo.dot_general %12991, %12993, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x128xf32>, tensor<96x128x8xf32>) -> tensor<96x1x8xf32> + %12995 = stablehlo.reshape %12994 : (tensor<96x1x8xf32>) -> tensor<3x32x1x8xf32> + %12996 = stablehlo.broadcast_in_dim %12995, dims = [0, 1, 2, 3] : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %12997 = stablehlo.divide %12996, %89 : tensor<3x32x1x8xf32> + %12998 = stablehlo.custom_call @byteir.softmax(%12997) {byteir_attrs = {axis = 3 : i64}} : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %12999 = stablehlo.reshape %12998 : (tensor<3x32x1x8xf32>) -> tensor<96x1x8xf32> + %13000 = stablehlo.reshape %12989 : (tensor<3x32x8x128xf32>) 
-> tensor<96x8x128xf32> + %13001 = stablehlo.broadcast_in_dim %13000, dims = [0, 1, 2] : (tensor<96x8x128xf32>) -> tensor<96x8x128xf32> + %13002 = stablehlo.dot_general %12999, %13001, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x8xf32>, tensor<96x8x128xf32>) -> tensor<96x1x128xf32> + %13003 = stablehlo.reshape %13002 : (tensor<96x1x128xf32>) -> tensor<3x32x1x128xf32> + %13004 = stablehlo.transpose %13003, dims = [0, 2, 1, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x1x32x128xf32> + %13005 = stablehlo.reshape %13004 : (tensor<3x1x32x128xf32>) -> tensor<3x1x4096xf32> + %13006 = stablehlo.reshape %13005 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %13007 = stablehlo.dot %13006, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %13008 = stablehlo.reshape %13007 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %13009 = stablehlo.add %12931, %13008 : tensor<3x1x4096xf32> + %13010 = stablehlo.broadcast_in_dim %13009, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %13011 = stablehlo.power %13010, %15 : tensor<3x1x4096xf32> + %13012 = stablehlo.reduce(%13011 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %13013 = stablehlo.reshape %13012 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %13014 = stablehlo.broadcast_in_dim %13013, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %13015 = stablehlo.divide %13014, %21 : tensor<3x1x1xf32> + %13016 = stablehlo.broadcast_in_dim %13015, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %13017 = stablehlo.add %13016, %25 : tensor<3x1x1xf32> + %13018 = stablehlo.rsqrt %13017 : tensor<3x1x1xf32> + %13019 = stablehlo.broadcast_in_dim %13018, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %13020 = stablehlo.multiply %13010, %13019 : tensor<3x1x4096xf32> + %13021 = stablehlo.broadcast_in_dim %13020, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %13022 = stablehlo.multiply %13021, %31 : tensor<3x1x4096xf32> + %13023 = stablehlo.reshape %13022 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %13024 = stablehlo.dot %13023, %117 : (tensor<3x4096xf32>, tensor<4096x8xf32>) -> tensor<3x8xf32> + %13025 = stablehlo.custom_call @byteir.softmax(%13024) {byteir_attrs = {axis = 1 : i64}} : (tensor<3x8xf32>) -> tensor<3x8xf32> + %13026:2 = stablehlo.custom_call @byteir.top_k(%13025) {byteir_attrs = {axis = [1], k = 2 : i64, sorted = true}} : (tensor<3x8xf32>) -> (tensor<3x2xf32>, tensor<3x2xi64>) + %13027 = stablehlo.reduce(%13026#0 init: %cst_20) applies stablehlo.add across dimensions = [1] : (tensor<3x2xf32>, tensor) -> tensor<3xf32> + %13028 = stablehlo.reshape %13027 : (tensor<3xf32>) -> tensor<3x1xf32> + %13029 = stablehlo.broadcast_in_dim %13026#0, dims = [0, 1] : (tensor<3x2xf32>) -> tensor<3x2xf32> + %13030 = stablehlo.broadcast_in_dim %13028, dims = [0, 1] : (tensor<3x1xf32>) -> tensor<3x2xf32> + %13031 = stablehlo.divide %13029, %13030 : tensor<3x2xf32> + %13032 = stablehlo.reshape %13026#1 : (tensor<3x2xi64>) -> tensor<3x2x1xi64> + %13033 = stablehlo.broadcast_in_dim %13032, dims = [0, 1, 2] : (tensor<3x2x1xi64>) -> tensor<3x2x8xi64> + %13034 = stablehlo.compare EQ, %13033, %137, SIGNED : (tensor<3x2x8xi64>, tensor<3x2x8xi64>) -> tensor<3x2x8xi1> + %13035 = stablehlo.convert %13034 : (tensor<3x2x8xi1>) -> tensor<3x2x8xi64> + %13036 = stablehlo.transpose %13035, dims = [2, 1, 0] : (tensor<3x2x8xi64>) -> tensor<8x2x3xi64> + %13037 = stablehlo.slice %13036 [0:1, 0:2, 
0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %13038 = stablehlo.reshape %13037 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %13039 = stablehlo.custom_call @byteir.non_zero(%13038) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_4560 = tensor.dim %13039, %c0 : tensor + %13040 = arith.index_cast %dim_4560 : index to i64 + %from_elements_4561 = tensor.from_elements %13040, %c1_i64 : tensor<2xi64> + %13041 = stablehlo.real_dynamic_slice %13039, %c_22, %from_elements_4561, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4562 = tensor.dim %13041, %c0 : tensor + %13042 = arith.index_cast %dim_4562 : index to i64 + %from_elements_4563 = tensor.from_elements %13042 : tensor<1xi64> + %13043 = stablehlo.dynamic_reshape %13041, %from_elements_4563 : (tensor, tensor<1xi64>) -> tensor + %from_elements_4564 = tensor.from_elements %13040, %c2_i64 : tensor<2xi64> + %13044 = stablehlo.real_dynamic_slice %13039, %c_24, %from_elements_4564, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4565 = tensor.dim %13044, %c0 : tensor + %13045 = arith.index_cast %dim_4565 : index to i64 + %from_elements_4566 = tensor.from_elements %13045 : tensor<1xi64> + %13046 = stablehlo.dynamic_reshape %13044, %from_elements_4566 : (tensor, tensor<1xi64>) -> tensor + %13047 = stablehlo.reshape %13023 : (tensor<3x4096xf32>) -> tensor<1x3x4096xf32> + %dim_4567 = tensor.dim %13046, %c0 : tensor + %13048 = arith.index_cast %dim_4567 : index to i64 + %from_elements_4568 = tensor.from_elements %13048, %c1_i64 : tensor<2xi64> + %13049 = stablehlo.dynamic_reshape %13046, %from_elements_4568 : (tensor, tensor<2xi64>) -> tensor + %dim_4569 = tensor.dim %13049, %c0 : tensor + %13050 = arith.index_cast %dim_4569 : index to i64 + %from_elements_4570 = tensor.from_elements %c1_i64, %13050, %c4096_i64 : tensor<3xi64> + %13051 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_4570, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4571 = tensor.dim %13051, %c1 : tensor<1x?x4096xi64> + %13052 = arith.index_cast %dim_4571 : index to i64 + %from_elements_4572 = tensor.from_elements %c1_i64, %13052, %c4096_i64, %c1_i64 : tensor<4xi64> + %13053 = stablehlo.dynamic_reshape %13051, %from_elements_4572 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %13054 = stablehlo.dynamic_broadcast_in_dim %13049, %from_elements_4570, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4573 = tensor.dim %13054, %c1 : tensor<1x?x4096xi64> + %13055 = arith.index_cast %dim_4573 : index to i64 + %from_elements_4574 = tensor.from_elements %c1_i64, %13055, %c4096_i64, %c1_i64 : tensor<4xi64> + %13056 = stablehlo.dynamic_reshape %13054, %from_elements_4574 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %13057 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_4570, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4575 = tensor.dim %13057, %c1 : tensor<1x?x4096xi64> + %13058 = arith.index_cast %dim_4575 : index to i64 + %from_elements_4576 = tensor.from_elements %c1_i64, %13058, %c4096_i64, %c1_i64 : tensor<4xi64> + %13059 = stablehlo.dynamic_reshape %13057, %from_elements_4576 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %13060 = stablehlo.concatenate %13053, %13056, %13059, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %13061 = "stablehlo.gather"(%13047, %13060) 
<{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %13062 = shape.shape_of %13061 : tensor<1x?x4096xf32> -> tensor<3xindex> + %13063 = shape.num_elements %13062 : tensor<3xindex> -> index + %13064 = stablehlo.compute_reshape_shape %13063, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %13065 = stablehlo.dynamic_reshape %13061, %13064 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %13066 = stablehlo.dot %13065, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %13067 = stablehlo.logistic %13066 : tensor + %13068 = shape.shape_of %13067 : tensor -> tensor<2xindex> + %13069 = shape.shape_of %13066 : tensor -> tensor<2xindex> + %13070 = shape.cstr_broadcastable %13068, %13069 : tensor<2xindex>, tensor<2xindex> + %13071 = shape.assuming %13070 -> (tensor) { + %19688 = shape.broadcast %13068, %13069 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %13067, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %13066, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %13072 = shape.shape_of %13071 : tensor -> tensor<2xindex> + %13073 = shape.cstr_broadcastable %13072, %13069 : tensor<2xindex>, tensor<2xindex> + %13074 = shape.assuming %13073 -> (tensor) { + %19688 = shape.broadcast %13072, %13069 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %13071, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %13066, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %13075 = stablehlo.dot %13074, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %13076 = stablehlo.reshape %13031 : (tensor<3x2xf32>) -> tensor<3x2x1xf32> + %dim_4577 = tensor.dim %13046, %c0 : tensor + %13077 = arith.index_cast %dim_4577 : index to i64 + %from_elements_4578 = tensor.from_elements %13077, %c1_i64 : tensor<2xi64> + %13078 = stablehlo.dynamic_reshape %13046, %from_elements_4578 : (tensor, tensor<2xi64>) -> tensor + %dim_4579 = tensor.dim %13043, %c0 : tensor + %13079 = arith.index_cast %dim_4579 : index to i64 + %from_elements_4580 = tensor.from_elements %13079, %c1_i64 : tensor<2xi64> + %13080 = stablehlo.dynamic_reshape %13043, %from_elements_4580 : (tensor, tensor<2xi64>) -> tensor + %13081 = stablehlo.concatenate %13078, %13080, dim = 1 : (tensor, tensor) -> tensor + %13082 = "stablehlo.gather"(%13076, %13081) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %13083 = shape.shape_of %13075 : tensor -> tensor<2xindex> + %13084 = shape.shape_of %13082 : tensor -> tensor<2xindex> + %13085 = shape.cstr_broadcastable %13083, %13084 : tensor<2xindex>, tensor<2xindex> + %13086 = shape.assuming %13085 -> (tensor) { + %19688 = shape.broadcast %13083, %13084 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %13075, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %13082, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + 
shape.assuming_yield %19691 : tensor + } + %13087 = shape.shape_of %13086 : tensor -> tensor<2xindex> + %13088 = stablehlo.dynamic_broadcast_in_dim %13086, %13087, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %13089 = stablehlo.dynamic_broadcast_in_dim %213, %13087, dims = [] : (tensor, tensor<2xindex>) -> tensor + %13090 = stablehlo.multiply %13088, %13089 : tensor + %dim_4581 = tensor.dim %13049, %c0 : tensor + %13091 = arith.index_cast %dim_4581 : index to i64 + %dim_4582 = tensor.dim %13086, %c0 : tensor + %13092 = arith.index_cast %dim_4582 : index to i64 + %13093 = arith.maxsi %13091, %13092 : i64 + %13094 = arith.index_cast %13093 : i64 to index + %from_elements_4583 = tensor.from_elements %13094, %c4096 : tensor<2xindex> + %13095 = stablehlo.dynamic_broadcast_in_dim %13049, %from_elements_4583, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_4584 = tensor.dim %13095, %c0 : tensor + %13096 = arith.index_cast %dim_4584 : index to i64 + %from_elements_4585 = tensor.from_elements %13096, %c4096_i64 : tensor<2xi64> + %13097 = stablehlo.real_dynamic_slice %13090, %c_22, %from_elements_4585, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_4586 = tensor.from_elements %13096, %c4096_i64, %c1_i64 : tensor<3xi64> + %13098 = stablehlo.dynamic_reshape %13095, %from_elements_4586 : (tensor, tensor<3xi64>) -> tensor + %13099 = stablehlo.dynamic_iota %from_elements_4586, dim = 1 : (tensor<3xi64>) -> tensor + %13100 = stablehlo.concatenate %13098, %13099, dim = 2 : (tensor, tensor) -> tensor + %13101 = "stablehlo.scatter"(%cst_2, %13100, %13097) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %13102 = stablehlo.slice %13036 [1:2, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %13103 = stablehlo.reshape %13102 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %13104 = stablehlo.custom_call @byteir.non_zero(%13103) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_4587 = tensor.dim %13104, %c0 : tensor + %13105 = arith.index_cast %dim_4587 : index to i64 + %from_elements_4588 = tensor.from_elements %13105, %c1_i64 : tensor<2xi64> + %13106 = stablehlo.real_dynamic_slice %13104, %c_22, %from_elements_4588, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4589 = tensor.dim %13106, %c0 : tensor + %13107 = arith.index_cast %dim_4589 : index to i64 + %from_elements_4590 = tensor.from_elements %13107 : tensor<1xi64> + %13108 = stablehlo.dynamic_reshape %13106, %from_elements_4590 : (tensor, tensor<1xi64>) -> tensor + %from_elements_4591 = tensor.from_elements %13105, %c2_i64 : tensor<2xi64> + %13109 = stablehlo.real_dynamic_slice %13104, %c_24, %from_elements_4591, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4592 = tensor.dim %13109, %c0 : tensor + %13110 = arith.index_cast %dim_4592 : index to i64 + %from_elements_4593 = tensor.from_elements %13110 : tensor<1xi64> + %13111 = stablehlo.dynamic_reshape %13109, %from_elements_4593 : (tensor, tensor<1xi64>) -> tensor + %dim_4594 = tensor.dim %13111, %c0 : tensor + %13112 = arith.index_cast %dim_4594 : index to i64 + %from_elements_4595 = tensor.from_elements %13112, %c1_i64 : tensor<2xi64> + %13113 = stablehlo.dynamic_reshape %13111, %from_elements_4595 : (tensor, tensor<2xi64>) 
-> tensor + %dim_4596 = tensor.dim %13113, %c0 : tensor + %13114 = arith.index_cast %dim_4596 : index to i64 + %from_elements_4597 = tensor.from_elements %c1_i64, %13114, %c4096_i64 : tensor<3xi64> + %13115 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_4597, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4598 = tensor.dim %13115, %c1 : tensor<1x?x4096xi64> + %13116 = arith.index_cast %dim_4598 : index to i64 + %from_elements_4599 = tensor.from_elements %c1_i64, %13116, %c4096_i64, %c1_i64 : tensor<4xi64> + %13117 = stablehlo.dynamic_reshape %13115, %from_elements_4599 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %13118 = stablehlo.dynamic_broadcast_in_dim %13113, %from_elements_4597, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4600 = tensor.dim %13118, %c1 : tensor<1x?x4096xi64> + %13119 = arith.index_cast %dim_4600 : index to i64 + %from_elements_4601 = tensor.from_elements %c1_i64, %13119, %c4096_i64, %c1_i64 : tensor<4xi64> + %13120 = stablehlo.dynamic_reshape %13118, %from_elements_4601 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %13121 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_4597, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4602 = tensor.dim %13121, %c1 : tensor<1x?x4096xi64> + %13122 = arith.index_cast %dim_4602 : index to i64 + %from_elements_4603 = tensor.from_elements %c1_i64, %13122, %c4096_i64, %c1_i64 : tensor<4xi64> + %13123 = stablehlo.dynamic_reshape %13121, %from_elements_4603 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %13124 = stablehlo.concatenate %13117, %13120, %13123, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %13125 = "stablehlo.gather"(%13047, %13124) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %13126 = shape.shape_of %13125 : tensor<1x?x4096xf32> -> tensor<3xindex> + %13127 = shape.num_elements %13126 : tensor<3xindex> -> index + %13128 = stablehlo.compute_reshape_shape %13127, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %13129 = stablehlo.dynamic_reshape %13125, %13128 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %13130 = stablehlo.dot %13129, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %13131 = stablehlo.logistic %13130 : tensor + %13132 = shape.shape_of %13131 : tensor -> tensor<2xindex> + %13133 = shape.shape_of %13130 : tensor -> tensor<2xindex> + %13134 = shape.cstr_broadcastable %13132, %13133 : tensor<2xindex>, tensor<2xindex> + %13135 = shape.assuming %13134 -> (tensor) { + %19688 = shape.broadcast %13132, %13133 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %13131, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %13130, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %13136 = shape.shape_of %13135 : tensor -> tensor<2xindex> + %13137 = shape.cstr_broadcastable %13136, %13133 : tensor<2xindex>, tensor<2xindex> + %13138 = shape.assuming %13137 -> (tensor) { + %19688 = shape.broadcast %13136, %13133 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %13135, %19688, dims = [0, 1] : 
(tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %13130, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %13139 = stablehlo.dot %13138, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_4604 = tensor.dim %13111, %c0 : tensor + %13140 = arith.index_cast %dim_4604 : index to i64 + %from_elements_4605 = tensor.from_elements %13140, %c1_i64 : tensor<2xi64> + %13141 = stablehlo.dynamic_reshape %13111, %from_elements_4605 : (tensor, tensor<2xi64>) -> tensor + %dim_4606 = tensor.dim %13108, %c0 : tensor + %13142 = arith.index_cast %dim_4606 : index to i64 + %from_elements_4607 = tensor.from_elements %13142, %c1_i64 : tensor<2xi64> + %13143 = stablehlo.dynamic_reshape %13108, %from_elements_4607 : (tensor, tensor<2xi64>) -> tensor + %13144 = stablehlo.concatenate %13141, %13143, dim = 1 : (tensor, tensor) -> tensor + %13145 = "stablehlo.gather"(%13076, %13144) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %13146 = shape.shape_of %13139 : tensor -> tensor<2xindex> + %13147 = shape.shape_of %13145 : tensor -> tensor<2xindex> + %13148 = shape.cstr_broadcastable %13146, %13147 : tensor<2xindex>, tensor<2xindex> + %13149 = shape.assuming %13148 -> (tensor) { + %19688 = shape.broadcast %13146, %13147 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %13139, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %13145, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %13150 = shape.shape_of %13149 : tensor -> tensor<2xindex> + %13151 = stablehlo.dynamic_broadcast_in_dim %13149, %13150, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %13152 = stablehlo.dynamic_broadcast_in_dim %213, %13150, dims = [] : (tensor, tensor<2xindex>) -> tensor + %13153 = stablehlo.multiply %13151, %13152 : tensor + %dim_4608 = tensor.dim %13113, %c0 : tensor + %13154 = arith.index_cast %dim_4608 : index to i64 + %dim_4609 = tensor.dim %13149, %c0 : tensor + %13155 = arith.index_cast %dim_4609 : index to i64 + %13156 = arith.maxsi %13154, %13155 : i64 + %13157 = arith.index_cast %13156 : i64 to index + %from_elements_4610 = tensor.from_elements %13157, %c4096 : tensor<2xindex> + %13158 = stablehlo.dynamic_broadcast_in_dim %13113, %from_elements_4610, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_4611 = tensor.dim %13158, %c0 : tensor + %13159 = arith.index_cast %dim_4611 : index to i64 + %from_elements_4612 = tensor.from_elements %13159, %c4096_i64 : tensor<2xi64> + %13160 = stablehlo.real_dynamic_slice %13153, %c_22, %from_elements_4612, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_4613 = tensor.from_elements %13159, %c4096_i64, %c1_i64 : tensor<3xi64> + %13161 = stablehlo.dynamic_reshape %13158, %from_elements_4613 : (tensor, tensor<3xi64>) -> tensor + %13162 = stablehlo.dynamic_iota %from_elements_4613, dim = 1 : (tensor<3xi64>) -> tensor + %13163 = stablehlo.concatenate %13161, %13162, dim = 2 : (tensor, tensor) -> tensor + %13164 = "stablehlo.scatter"(%13101, %13163, %13160) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): 
+ %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %13165 = stablehlo.slice %13036 [2:3, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %13166 = stablehlo.reshape %13165 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %13167 = stablehlo.custom_call @byteir.non_zero(%13166) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_4614 = tensor.dim %13167, %c0 : tensor + %13168 = arith.index_cast %dim_4614 : index to i64 + %from_elements_4615 = tensor.from_elements %13168, %c1_i64 : tensor<2xi64> + %13169 = stablehlo.real_dynamic_slice %13167, %c_22, %from_elements_4615, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4616 = tensor.dim %13169, %c0 : tensor + %13170 = arith.index_cast %dim_4616 : index to i64 + %from_elements_4617 = tensor.from_elements %13170 : tensor<1xi64> + %13171 = stablehlo.dynamic_reshape %13169, %from_elements_4617 : (tensor, tensor<1xi64>) -> tensor + %from_elements_4618 = tensor.from_elements %13168, %c2_i64 : tensor<2xi64> + %13172 = stablehlo.real_dynamic_slice %13167, %c_24, %from_elements_4618, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4619 = tensor.dim %13172, %c0 : tensor + %13173 = arith.index_cast %dim_4619 : index to i64 + %from_elements_4620 = tensor.from_elements %13173 : tensor<1xi64> + %13174 = stablehlo.dynamic_reshape %13172, %from_elements_4620 : (tensor, tensor<1xi64>) -> tensor + %dim_4621 = tensor.dim %13174, %c0 : tensor + %13175 = arith.index_cast %dim_4621 : index to i64 + %from_elements_4622 = tensor.from_elements %13175, %c1_i64 : tensor<2xi64> + %13176 = stablehlo.dynamic_reshape %13174, %from_elements_4622 : (tensor, tensor<2xi64>) -> tensor + %dim_4623 = tensor.dim %13176, %c0 : tensor + %13177 = arith.index_cast %dim_4623 : index to i64 + %from_elements_4624 = tensor.from_elements %c1_i64, %13177, %c4096_i64 : tensor<3xi64> + %13178 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_4624, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4625 = tensor.dim %13178, %c1 : tensor<1x?x4096xi64> + %13179 = arith.index_cast %dim_4625 : index to i64 + %from_elements_4626 = tensor.from_elements %c1_i64, %13179, %c4096_i64, %c1_i64 : tensor<4xi64> + %13180 = stablehlo.dynamic_reshape %13178, %from_elements_4626 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %13181 = stablehlo.dynamic_broadcast_in_dim %13176, %from_elements_4624, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4627 = tensor.dim %13181, %c1 : tensor<1x?x4096xi64> + %13182 = arith.index_cast %dim_4627 : index to i64 + %from_elements_4628 = tensor.from_elements %c1_i64, %13182, %c4096_i64, %c1_i64 : tensor<4xi64> + %13183 = stablehlo.dynamic_reshape %13181, %from_elements_4628 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %13184 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_4624, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4629 = tensor.dim %13184, %c1 : tensor<1x?x4096xi64> + %13185 = arith.index_cast %dim_4629 : index to i64 + %from_elements_4630 = tensor.from_elements %c1_i64, %13185, %c4096_i64, %c1_i64 : tensor<4xi64> + %13186 = stablehlo.dynamic_reshape %13184, %from_elements_4630 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %13187 = stablehlo.concatenate %13180, %13183, %13186, dim = 3 : (tensor<1x?x4096x1xi64>, 
tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %13188 = "stablehlo.gather"(%13047, %13187) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %13189 = shape.shape_of %13188 : tensor<1x?x4096xf32> -> tensor<3xindex> + %13190 = shape.num_elements %13189 : tensor<3xindex> -> index + %13191 = stablehlo.compute_reshape_shape %13190, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %13192 = stablehlo.dynamic_reshape %13188, %13191 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %13193 = stablehlo.dot %13192, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %13194 = stablehlo.logistic %13193 : tensor + %13195 = shape.shape_of %13194 : tensor -> tensor<2xindex> + %13196 = shape.shape_of %13193 : tensor -> tensor<2xindex> + %13197 = shape.cstr_broadcastable %13195, %13196 : tensor<2xindex>, tensor<2xindex> + %13198 = shape.assuming %13197 -> (tensor) { + %19688 = shape.broadcast %13195, %13196 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %13194, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %13193, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %13199 = shape.shape_of %13198 : tensor -> tensor<2xindex> + %13200 = shape.cstr_broadcastable %13199, %13196 : tensor<2xindex>, tensor<2xindex> + %13201 = shape.assuming %13200 -> (tensor) { + %19688 = shape.broadcast %13199, %13196 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %13198, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %13193, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %13202 = stablehlo.dot %13201, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_4631 = tensor.dim %13174, %c0 : tensor + %13203 = arith.index_cast %dim_4631 : index to i64 + %from_elements_4632 = tensor.from_elements %13203, %c1_i64 : tensor<2xi64> + %13204 = stablehlo.dynamic_reshape %13174, %from_elements_4632 : (tensor, tensor<2xi64>) -> tensor + %dim_4633 = tensor.dim %13171, %c0 : tensor + %13205 = arith.index_cast %dim_4633 : index to i64 + %from_elements_4634 = tensor.from_elements %13205, %c1_i64 : tensor<2xi64> + %13206 = stablehlo.dynamic_reshape %13171, %from_elements_4634 : (tensor, tensor<2xi64>) -> tensor + %13207 = stablehlo.concatenate %13204, %13206, dim = 1 : (tensor, tensor) -> tensor + %13208 = "stablehlo.gather"(%13076, %13207) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %13209 = shape.shape_of %13202 : tensor -> tensor<2xindex> + %13210 = shape.shape_of %13208 : tensor -> tensor<2xindex> + %13211 = shape.cstr_broadcastable %13209, %13210 : tensor<2xindex>, tensor<2xindex> + %13212 = shape.assuming %13211 -> (tensor) { + %19688 = shape.broadcast %13209, %13210 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %13202, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %13208, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply 
%19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %13213 = shape.shape_of %13212 : tensor -> tensor<2xindex> + %13214 = stablehlo.dynamic_broadcast_in_dim %13212, %13213, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %13215 = stablehlo.dynamic_broadcast_in_dim %213, %13213, dims = [] : (tensor, tensor<2xindex>) -> tensor + %13216 = stablehlo.multiply %13214, %13215 : tensor + %dim_4635 = tensor.dim %13176, %c0 : tensor + %13217 = arith.index_cast %dim_4635 : index to i64 + %dim_4636 = tensor.dim %13212, %c0 : tensor + %13218 = arith.index_cast %dim_4636 : index to i64 + %13219 = arith.maxsi %13217, %13218 : i64 + %13220 = arith.index_cast %13219 : i64 to index + %from_elements_4637 = tensor.from_elements %13220, %c4096 : tensor<2xindex> + %13221 = stablehlo.dynamic_broadcast_in_dim %13176, %from_elements_4637, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_4638 = tensor.dim %13221, %c0 : tensor + %13222 = arith.index_cast %dim_4638 : index to i64 + %from_elements_4639 = tensor.from_elements %13222, %c4096_i64 : tensor<2xi64> + %13223 = stablehlo.real_dynamic_slice %13216, %c_22, %from_elements_4639, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_4640 = tensor.from_elements %13222, %c4096_i64, %c1_i64 : tensor<3xi64> + %13224 = stablehlo.dynamic_reshape %13221, %from_elements_4640 : (tensor, tensor<3xi64>) -> tensor + %13225 = stablehlo.dynamic_iota %from_elements_4640, dim = 1 : (tensor<3xi64>) -> tensor + %13226 = stablehlo.concatenate %13224, %13225, dim = 2 : (tensor, tensor) -> tensor + %13227 = "stablehlo.scatter"(%13164, %13226, %13223) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %13228 = stablehlo.slice %13036 [3:4, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %13229 = stablehlo.reshape %13228 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %13230 = stablehlo.custom_call @byteir.non_zero(%13229) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_4641 = tensor.dim %13230, %c0 : tensor + %13231 = arith.index_cast %dim_4641 : index to i64 + %from_elements_4642 = tensor.from_elements %13231, %c1_i64 : tensor<2xi64> + %13232 = stablehlo.real_dynamic_slice %13230, %c_22, %from_elements_4642, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4643 = tensor.dim %13232, %c0 : tensor + %13233 = arith.index_cast %dim_4643 : index to i64 + %from_elements_4644 = tensor.from_elements %13233 : tensor<1xi64> + %13234 = stablehlo.dynamic_reshape %13232, %from_elements_4644 : (tensor, tensor<1xi64>) -> tensor + %from_elements_4645 = tensor.from_elements %13231, %c2_i64 : tensor<2xi64> + %13235 = stablehlo.real_dynamic_slice %13230, %c_24, %from_elements_4645, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4646 = tensor.dim %13235, %c0 : tensor + %13236 = arith.index_cast %dim_4646 : index to i64 + %from_elements_4647 = tensor.from_elements %13236 : tensor<1xi64> + %13237 = stablehlo.dynamic_reshape %13235, %from_elements_4647 : (tensor, tensor<1xi64>) -> tensor + %dim_4648 = tensor.dim %13237, %c0 : tensor + %13238 = arith.index_cast %dim_4648 : index to i64 + %from_elements_4649 = tensor.from_elements %13238, %c1_i64 : tensor<2xi64> + %13239 = stablehlo.dynamic_reshape %13237, %from_elements_4649 
: (tensor, tensor<2xi64>) -> tensor + %dim_4650 = tensor.dim %13239, %c0 : tensor + %13240 = arith.index_cast %dim_4650 : index to i64 + %from_elements_4651 = tensor.from_elements %c1_i64, %13240, %c4096_i64 : tensor<3xi64> + %13241 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_4651, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4652 = tensor.dim %13241, %c1 : tensor<1x?x4096xi64> + %13242 = arith.index_cast %dim_4652 : index to i64 + %from_elements_4653 = tensor.from_elements %c1_i64, %13242, %c4096_i64, %c1_i64 : tensor<4xi64> + %13243 = stablehlo.dynamic_reshape %13241, %from_elements_4653 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %13244 = stablehlo.dynamic_broadcast_in_dim %13239, %from_elements_4651, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4654 = tensor.dim %13244, %c1 : tensor<1x?x4096xi64> + %13245 = arith.index_cast %dim_4654 : index to i64 + %from_elements_4655 = tensor.from_elements %c1_i64, %13245, %c4096_i64, %c1_i64 : tensor<4xi64> + %13246 = stablehlo.dynamic_reshape %13244, %from_elements_4655 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %13247 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_4651, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4656 = tensor.dim %13247, %c1 : tensor<1x?x4096xi64> + %13248 = arith.index_cast %dim_4656 : index to i64 + %from_elements_4657 = tensor.from_elements %c1_i64, %13248, %c4096_i64, %c1_i64 : tensor<4xi64> + %13249 = stablehlo.dynamic_reshape %13247, %from_elements_4657 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %13250 = stablehlo.concatenate %13243, %13246, %13249, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %13251 = "stablehlo.gather"(%13047, %13250) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %13252 = shape.shape_of %13251 : tensor<1x?x4096xf32> -> tensor<3xindex> + %13253 = shape.num_elements %13252 : tensor<3xindex> -> index + %13254 = stablehlo.compute_reshape_shape %13253, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %13255 = stablehlo.dynamic_reshape %13251, %13254 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %13256 = stablehlo.dot %13255, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %13257 = stablehlo.logistic %13256 : tensor + %13258 = shape.shape_of %13257 : tensor -> tensor<2xindex> + %13259 = shape.shape_of %13256 : tensor -> tensor<2xindex> + %13260 = shape.cstr_broadcastable %13258, %13259 : tensor<2xindex>, tensor<2xindex> + %13261 = shape.assuming %13260 -> (tensor) { + %19688 = shape.broadcast %13258, %13259 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %13257, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %13256, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %13262 = shape.shape_of %13261 : tensor -> tensor<2xindex> + %13263 = shape.cstr_broadcastable %13262, %13259 : tensor<2xindex>, tensor<2xindex> + %13264 = shape.assuming %13263 -> (tensor) { + %19688 = shape.broadcast %13262, %13259 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %13261, 
%19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %13256, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %13265 = stablehlo.dot %13264, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_4658 = tensor.dim %13237, %c0 : tensor + %13266 = arith.index_cast %dim_4658 : index to i64 + %from_elements_4659 = tensor.from_elements %13266, %c1_i64 : tensor<2xi64> + %13267 = stablehlo.dynamic_reshape %13237, %from_elements_4659 : (tensor, tensor<2xi64>) -> tensor + %dim_4660 = tensor.dim %13234, %c0 : tensor + %13268 = arith.index_cast %dim_4660 : index to i64 + %from_elements_4661 = tensor.from_elements %13268, %c1_i64 : tensor<2xi64> + %13269 = stablehlo.dynamic_reshape %13234, %from_elements_4661 : (tensor, tensor<2xi64>) -> tensor + %13270 = stablehlo.concatenate %13267, %13269, dim = 1 : (tensor, tensor) -> tensor + %13271 = "stablehlo.gather"(%13076, %13270) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %13272 = shape.shape_of %13265 : tensor -> tensor<2xindex> + %13273 = shape.shape_of %13271 : tensor -> tensor<2xindex> + %13274 = shape.cstr_broadcastable %13272, %13273 : tensor<2xindex>, tensor<2xindex> + %13275 = shape.assuming %13274 -> (tensor) { + %19688 = shape.broadcast %13272, %13273 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %13265, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %13271, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %13276 = shape.shape_of %13275 : tensor -> tensor<2xindex> + %13277 = stablehlo.dynamic_broadcast_in_dim %13275, %13276, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %13278 = stablehlo.dynamic_broadcast_in_dim %213, %13276, dims = [] : (tensor, tensor<2xindex>) -> tensor + %13279 = stablehlo.multiply %13277, %13278 : tensor + %dim_4662 = tensor.dim %13239, %c0 : tensor + %13280 = arith.index_cast %dim_4662 : index to i64 + %dim_4663 = tensor.dim %13275, %c0 : tensor + %13281 = arith.index_cast %dim_4663 : index to i64 + %13282 = arith.maxsi %13280, %13281 : i64 + %13283 = arith.index_cast %13282 : i64 to index + %from_elements_4664 = tensor.from_elements %13283, %c4096 : tensor<2xindex> + %13284 = stablehlo.dynamic_broadcast_in_dim %13239, %from_elements_4664, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_4665 = tensor.dim %13284, %c0 : tensor + %13285 = arith.index_cast %dim_4665 : index to i64 + %from_elements_4666 = tensor.from_elements %13285, %c4096_i64 : tensor<2xi64> + %13286 = stablehlo.real_dynamic_slice %13279, %c_22, %from_elements_4666, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_4667 = tensor.from_elements %13285, %c4096_i64, %c1_i64 : tensor<3xi64> + %13287 = stablehlo.dynamic_reshape %13284, %from_elements_4667 : (tensor, tensor<3xi64>) -> tensor + %13288 = stablehlo.dynamic_iota %from_elements_4667, dim = 1 : (tensor<3xi64>) -> tensor + %13289 = stablehlo.concatenate %13287, %13288, dim = 2 : (tensor, tensor) -> tensor + %13290 = "stablehlo.scatter"(%13227, %13289, %13286) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: 
tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %13291 = stablehlo.slice %13036 [4:5, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %13292 = stablehlo.reshape %13291 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %13293 = stablehlo.custom_call @byteir.non_zero(%13292) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_4668 = tensor.dim %13293, %c0 : tensor + %13294 = arith.index_cast %dim_4668 : index to i64 + %from_elements_4669 = tensor.from_elements %13294, %c1_i64 : tensor<2xi64> + %13295 = stablehlo.real_dynamic_slice %13293, %c_22, %from_elements_4669, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4670 = tensor.dim %13295, %c0 : tensor + %13296 = arith.index_cast %dim_4670 : index to i64 + %from_elements_4671 = tensor.from_elements %13296 : tensor<1xi64> + %13297 = stablehlo.dynamic_reshape %13295, %from_elements_4671 : (tensor, tensor<1xi64>) -> tensor + %from_elements_4672 = tensor.from_elements %13294, %c2_i64 : tensor<2xi64> + %13298 = stablehlo.real_dynamic_slice %13293, %c_24, %from_elements_4672, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4673 = tensor.dim %13298, %c0 : tensor + %13299 = arith.index_cast %dim_4673 : index to i64 + %from_elements_4674 = tensor.from_elements %13299 : tensor<1xi64> + %13300 = stablehlo.dynamic_reshape %13298, %from_elements_4674 : (tensor, tensor<1xi64>) -> tensor + %dim_4675 = tensor.dim %13300, %c0 : tensor + %13301 = arith.index_cast %dim_4675 : index to i64 + %from_elements_4676 = tensor.from_elements %13301, %c1_i64 : tensor<2xi64> + %13302 = stablehlo.dynamic_reshape %13300, %from_elements_4676 : (tensor, tensor<2xi64>) -> tensor + %dim_4677 = tensor.dim %13302, %c0 : tensor + %13303 = arith.index_cast %dim_4677 : index to i64 + %from_elements_4678 = tensor.from_elements %c1_i64, %13303, %c4096_i64 : tensor<3xi64> + %13304 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_4678, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4679 = tensor.dim %13304, %c1 : tensor<1x?x4096xi64> + %13305 = arith.index_cast %dim_4679 : index to i64 + %from_elements_4680 = tensor.from_elements %c1_i64, %13305, %c4096_i64, %c1_i64 : tensor<4xi64> + %13306 = stablehlo.dynamic_reshape %13304, %from_elements_4680 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %13307 = stablehlo.dynamic_broadcast_in_dim %13302, %from_elements_4678, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4681 = tensor.dim %13307, %c1 : tensor<1x?x4096xi64> + %13308 = arith.index_cast %dim_4681 : index to i64 + %from_elements_4682 = tensor.from_elements %c1_i64, %13308, %c4096_i64, %c1_i64 : tensor<4xi64> + %13309 = stablehlo.dynamic_reshape %13307, %from_elements_4682 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %13310 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_4678, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4683 = tensor.dim %13310, %c1 : tensor<1x?x4096xi64> + %13311 = arith.index_cast %dim_4683 : index to i64 + %from_elements_4684 = tensor.from_elements %c1_i64, %13311, %c4096_i64, %c1_i64 : tensor<4xi64> + %13312 = stablehlo.dynamic_reshape %13310, %from_elements_4684 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %13313 = stablehlo.concatenate %13306, %13309, %13312, dim = 3 : 
(tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %13314 = "stablehlo.gather"(%13047, %13313) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %13315 = shape.shape_of %13314 : tensor<1x?x4096xf32> -> tensor<3xindex> + %13316 = shape.num_elements %13315 : tensor<3xindex> -> index + %13317 = stablehlo.compute_reshape_shape %13316, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %13318 = stablehlo.dynamic_reshape %13314, %13317 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %13319 = stablehlo.dot %13318, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %13320 = stablehlo.logistic %13319 : tensor + %13321 = shape.shape_of %13320 : tensor -> tensor<2xindex> + %13322 = shape.shape_of %13319 : tensor -> tensor<2xindex> + %13323 = shape.cstr_broadcastable %13321, %13322 : tensor<2xindex>, tensor<2xindex> + %13324 = shape.assuming %13323 -> (tensor) { + %19688 = shape.broadcast %13321, %13322 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %13320, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %13319, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %13325 = shape.shape_of %13324 : tensor -> tensor<2xindex> + %13326 = shape.cstr_broadcastable %13325, %13322 : tensor<2xindex>, tensor<2xindex> + %13327 = shape.assuming %13326 -> (tensor) { + %19688 = shape.broadcast %13325, %13322 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %13324, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %13319, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %13328 = stablehlo.dot %13327, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_4685 = tensor.dim %13300, %c0 : tensor + %13329 = arith.index_cast %dim_4685 : index to i64 + %from_elements_4686 = tensor.from_elements %13329, %c1_i64 : tensor<2xi64> + %13330 = stablehlo.dynamic_reshape %13300, %from_elements_4686 : (tensor, tensor<2xi64>) -> tensor + %dim_4687 = tensor.dim %13297, %c0 : tensor + %13331 = arith.index_cast %dim_4687 : index to i64 + %from_elements_4688 = tensor.from_elements %13331, %c1_i64 : tensor<2xi64> + %13332 = stablehlo.dynamic_reshape %13297, %from_elements_4688 : (tensor, tensor<2xi64>) -> tensor + %13333 = stablehlo.concatenate %13330, %13332, dim = 1 : (tensor, tensor) -> tensor + %13334 = "stablehlo.gather"(%13076, %13333) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %13335 = shape.shape_of %13328 : tensor -> tensor<2xindex> + %13336 = shape.shape_of %13334 : tensor -> tensor<2xindex> + %13337 = shape.cstr_broadcastable %13335, %13336 : tensor<2xindex>, tensor<2xindex> + %13338 = shape.assuming %13337 -> (tensor) { + %19688 = shape.broadcast %13335, %13336 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %13328, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %13334, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + 
%19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %13339 = shape.shape_of %13338 : tensor -> tensor<2xindex> + %13340 = stablehlo.dynamic_broadcast_in_dim %13338, %13339, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %13341 = stablehlo.dynamic_broadcast_in_dim %213, %13339, dims = [] : (tensor, tensor<2xindex>) -> tensor + %13342 = stablehlo.multiply %13340, %13341 : tensor + %dim_4689 = tensor.dim %13302, %c0 : tensor + %13343 = arith.index_cast %dim_4689 : index to i64 + %dim_4690 = tensor.dim %13338, %c0 : tensor + %13344 = arith.index_cast %dim_4690 : index to i64 + %13345 = arith.maxsi %13343, %13344 : i64 + %13346 = arith.index_cast %13345 : i64 to index + %from_elements_4691 = tensor.from_elements %13346, %c4096 : tensor<2xindex> + %13347 = stablehlo.dynamic_broadcast_in_dim %13302, %from_elements_4691, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_4692 = tensor.dim %13347, %c0 : tensor + %13348 = arith.index_cast %dim_4692 : index to i64 + %from_elements_4693 = tensor.from_elements %13348, %c4096_i64 : tensor<2xi64> + %13349 = stablehlo.real_dynamic_slice %13342, %c_22, %from_elements_4693, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_4694 = tensor.from_elements %13348, %c4096_i64, %c1_i64 : tensor<3xi64> + %13350 = stablehlo.dynamic_reshape %13347, %from_elements_4694 : (tensor, tensor<3xi64>) -> tensor + %13351 = stablehlo.dynamic_iota %from_elements_4694, dim = 1 : (tensor<3xi64>) -> tensor + %13352 = stablehlo.concatenate %13350, %13351, dim = 2 : (tensor, tensor) -> tensor + %13353 = "stablehlo.scatter"(%13290, %13352, %13349) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %13354 = stablehlo.slice %13036 [5:6, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %13355 = stablehlo.reshape %13354 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %13356 = stablehlo.custom_call @byteir.non_zero(%13355) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_4695 = tensor.dim %13356, %c0 : tensor + %13357 = arith.index_cast %dim_4695 : index to i64 + %from_elements_4696 = tensor.from_elements %13357, %c1_i64 : tensor<2xi64> + %13358 = stablehlo.real_dynamic_slice %13356, %c_22, %from_elements_4696, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4697 = tensor.dim %13358, %c0 : tensor + %13359 = arith.index_cast %dim_4697 : index to i64 + %from_elements_4698 = tensor.from_elements %13359 : tensor<1xi64> + %13360 = stablehlo.dynamic_reshape %13358, %from_elements_4698 : (tensor, tensor<1xi64>) -> tensor + %from_elements_4699 = tensor.from_elements %13357, %c2_i64 : tensor<2xi64> + %13361 = stablehlo.real_dynamic_slice %13356, %c_24, %from_elements_4699, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4700 = tensor.dim %13361, %c0 : tensor + %13362 = arith.index_cast %dim_4700 : index to i64 + %from_elements_4701 = tensor.from_elements %13362 : tensor<1xi64> + %13363 = stablehlo.dynamic_reshape %13361, %from_elements_4701 : (tensor, tensor<1xi64>) -> tensor + %dim_4702 = tensor.dim %13363, %c0 : tensor + %13364 = arith.index_cast %dim_4702 : index to i64 + %from_elements_4703 = tensor.from_elements %13364, %c1_i64 : tensor<2xi64> + %13365 = stablehlo.dynamic_reshape 
%13363, %from_elements_4703 : (tensor, tensor<2xi64>) -> tensor + %dim_4704 = tensor.dim %13365, %c0 : tensor + %13366 = arith.index_cast %dim_4704 : index to i64 + %from_elements_4705 = tensor.from_elements %c1_i64, %13366, %c4096_i64 : tensor<3xi64> + %13367 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_4705, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4706 = tensor.dim %13367, %c1 : tensor<1x?x4096xi64> + %13368 = arith.index_cast %dim_4706 : index to i64 + %from_elements_4707 = tensor.from_elements %c1_i64, %13368, %c4096_i64, %c1_i64 : tensor<4xi64> + %13369 = stablehlo.dynamic_reshape %13367, %from_elements_4707 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %13370 = stablehlo.dynamic_broadcast_in_dim %13365, %from_elements_4705, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4708 = tensor.dim %13370, %c1 : tensor<1x?x4096xi64> + %13371 = arith.index_cast %dim_4708 : index to i64 + %from_elements_4709 = tensor.from_elements %c1_i64, %13371, %c4096_i64, %c1_i64 : tensor<4xi64> + %13372 = stablehlo.dynamic_reshape %13370, %from_elements_4709 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %13373 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_4705, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4710 = tensor.dim %13373, %c1 : tensor<1x?x4096xi64> + %13374 = arith.index_cast %dim_4710 : index to i64 + %from_elements_4711 = tensor.from_elements %c1_i64, %13374, %c4096_i64, %c1_i64 : tensor<4xi64> + %13375 = stablehlo.dynamic_reshape %13373, %from_elements_4711 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %13376 = stablehlo.concatenate %13369, %13372, %13375, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %13377 = "stablehlo.gather"(%13047, %13376) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %13378 = shape.shape_of %13377 : tensor<1x?x4096xf32> -> tensor<3xindex> + %13379 = shape.num_elements %13378 : tensor<3xindex> -> index + %13380 = stablehlo.compute_reshape_shape %13379, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %13381 = stablehlo.dynamic_reshape %13377, %13380 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %13382 = stablehlo.dot %13381, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %13383 = stablehlo.logistic %13382 : tensor + %13384 = shape.shape_of %13383 : tensor -> tensor<2xindex> + %13385 = shape.shape_of %13382 : tensor -> tensor<2xindex> + %13386 = shape.cstr_broadcastable %13384, %13385 : tensor<2xindex>, tensor<2xindex> + %13387 = shape.assuming %13386 -> (tensor) { + %19688 = shape.broadcast %13384, %13385 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %13383, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %13382, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %13388 = shape.shape_of %13387 : tensor -> tensor<2xindex> + %13389 = shape.cstr_broadcastable %13388, %13385 : tensor<2xindex>, tensor<2xindex> + %13390 = shape.assuming %13389 -> (tensor) { + %19688 = shape.broadcast %13388, %13385 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = 
stablehlo.dynamic_broadcast_in_dim %13387, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %13382, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %13391 = stablehlo.dot %13390, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_4712 = tensor.dim %13363, %c0 : tensor + %13392 = arith.index_cast %dim_4712 : index to i64 + %from_elements_4713 = tensor.from_elements %13392, %c1_i64 : tensor<2xi64> + %13393 = stablehlo.dynamic_reshape %13363, %from_elements_4713 : (tensor, tensor<2xi64>) -> tensor + %dim_4714 = tensor.dim %13360, %c0 : tensor + %13394 = arith.index_cast %dim_4714 : index to i64 + %from_elements_4715 = tensor.from_elements %13394, %c1_i64 : tensor<2xi64> + %13395 = stablehlo.dynamic_reshape %13360, %from_elements_4715 : (tensor, tensor<2xi64>) -> tensor + %13396 = stablehlo.concatenate %13393, %13395, dim = 1 : (tensor, tensor) -> tensor + %13397 = "stablehlo.gather"(%13076, %13396) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %13398 = shape.shape_of %13391 : tensor -> tensor<2xindex> + %13399 = shape.shape_of %13397 : tensor -> tensor<2xindex> + %13400 = shape.cstr_broadcastable %13398, %13399 : tensor<2xindex>, tensor<2xindex> + %13401 = shape.assuming %13400 -> (tensor) { + %19688 = shape.broadcast %13398, %13399 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %13391, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %13397, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %13402 = shape.shape_of %13401 : tensor -> tensor<2xindex> + %13403 = stablehlo.dynamic_broadcast_in_dim %13401, %13402, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %13404 = stablehlo.dynamic_broadcast_in_dim %213, %13402, dims = [] : (tensor, tensor<2xindex>) -> tensor + %13405 = stablehlo.multiply %13403, %13404 : tensor + %dim_4716 = tensor.dim %13365, %c0 : tensor + %13406 = arith.index_cast %dim_4716 : index to i64 + %dim_4717 = tensor.dim %13401, %c0 : tensor + %13407 = arith.index_cast %dim_4717 : index to i64 + %13408 = arith.maxsi %13406, %13407 : i64 + %13409 = arith.index_cast %13408 : i64 to index + %from_elements_4718 = tensor.from_elements %13409, %c4096 : tensor<2xindex> + %13410 = stablehlo.dynamic_broadcast_in_dim %13365, %from_elements_4718, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_4719 = tensor.dim %13410, %c0 : tensor + %13411 = arith.index_cast %dim_4719 : index to i64 + %from_elements_4720 = tensor.from_elements %13411, %c4096_i64 : tensor<2xi64> + %13412 = stablehlo.real_dynamic_slice %13405, %c_22, %from_elements_4720, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_4721 = tensor.from_elements %13411, %c4096_i64, %c1_i64 : tensor<3xi64> + %13413 = stablehlo.dynamic_reshape %13410, %from_elements_4721 : (tensor, tensor<3xi64>) -> tensor + %13414 = stablehlo.dynamic_iota %from_elements_4721, dim = 1 : (tensor<3xi64>) -> tensor + %13415 = stablehlo.concatenate %13413, %13414, dim = 2 : (tensor, tensor) -> tensor + %13416 = "stablehlo.scatter"(%13353, %13415, %13412) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, 
unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %13417 = stablehlo.slice %13036 [6:7, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %13418 = stablehlo.reshape %13417 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %13419 = stablehlo.custom_call @byteir.non_zero(%13418) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_4722 = tensor.dim %13419, %c0 : tensor + %13420 = arith.index_cast %dim_4722 : index to i64 + %from_elements_4723 = tensor.from_elements %13420, %c1_i64 : tensor<2xi64> + %13421 = stablehlo.real_dynamic_slice %13419, %c_22, %from_elements_4723, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4724 = tensor.dim %13421, %c0 : tensor + %13422 = arith.index_cast %dim_4724 : index to i64 + %from_elements_4725 = tensor.from_elements %13422 : tensor<1xi64> + %13423 = stablehlo.dynamic_reshape %13421, %from_elements_4725 : (tensor, tensor<1xi64>) -> tensor + %from_elements_4726 = tensor.from_elements %13420, %c2_i64 : tensor<2xi64> + %13424 = stablehlo.real_dynamic_slice %13419, %c_24, %from_elements_4726, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4727 = tensor.dim %13424, %c0 : tensor + %13425 = arith.index_cast %dim_4727 : index to i64 + %from_elements_4728 = tensor.from_elements %13425 : tensor<1xi64> + %13426 = stablehlo.dynamic_reshape %13424, %from_elements_4728 : (tensor, tensor<1xi64>) -> tensor + %dim_4729 = tensor.dim %13426, %c0 : tensor + %13427 = arith.index_cast %dim_4729 : index to i64 + %from_elements_4730 = tensor.from_elements %13427, %c1_i64 : tensor<2xi64> + %13428 = stablehlo.dynamic_reshape %13426, %from_elements_4730 : (tensor, tensor<2xi64>) -> tensor + %dim_4731 = tensor.dim %13428, %c0 : tensor + %13429 = arith.index_cast %dim_4731 : index to i64 + %from_elements_4732 = tensor.from_elements %c1_i64, %13429, %c4096_i64 : tensor<3xi64> + %13430 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_4732, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4733 = tensor.dim %13430, %c1 : tensor<1x?x4096xi64> + %13431 = arith.index_cast %dim_4733 : index to i64 + %from_elements_4734 = tensor.from_elements %c1_i64, %13431, %c4096_i64, %c1_i64 : tensor<4xi64> + %13432 = stablehlo.dynamic_reshape %13430, %from_elements_4734 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %13433 = stablehlo.dynamic_broadcast_in_dim %13428, %from_elements_4732, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4735 = tensor.dim %13433, %c1 : tensor<1x?x4096xi64> + %13434 = arith.index_cast %dim_4735 : index to i64 + %from_elements_4736 = tensor.from_elements %c1_i64, %13434, %c4096_i64, %c1_i64 : tensor<4xi64> + %13435 = stablehlo.dynamic_reshape %13433, %from_elements_4736 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %13436 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_4732, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4737 = tensor.dim %13436, %c1 : tensor<1x?x4096xi64> + %13437 = arith.index_cast %dim_4737 : index to i64 + %from_elements_4738 = tensor.from_elements %c1_i64, %13437, %c4096_i64, %c1_i64 : tensor<4xi64> + %13438 = stablehlo.dynamic_reshape %13436, %from_elements_4738 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %13439 = stablehlo.concatenate %13432, 
%13435, %13438, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %13440 = "stablehlo.gather"(%13047, %13439) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %13441 = shape.shape_of %13440 : tensor<1x?x4096xf32> -> tensor<3xindex> + %13442 = shape.num_elements %13441 : tensor<3xindex> -> index + %13443 = stablehlo.compute_reshape_shape %13442, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %13444 = stablehlo.dynamic_reshape %13440, %13443 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %13445 = stablehlo.dot %13444, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %13446 = stablehlo.logistic %13445 : tensor + %13447 = shape.shape_of %13446 : tensor -> tensor<2xindex> + %13448 = shape.shape_of %13445 : tensor -> tensor<2xindex> + %13449 = shape.cstr_broadcastable %13447, %13448 : tensor<2xindex>, tensor<2xindex> + %13450 = shape.assuming %13449 -> (tensor) { + %19688 = shape.broadcast %13447, %13448 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %13446, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %13445, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %13451 = shape.shape_of %13450 : tensor -> tensor<2xindex> + %13452 = shape.cstr_broadcastable %13451, %13448 : tensor<2xindex>, tensor<2xindex> + %13453 = shape.assuming %13452 -> (tensor) { + %19688 = shape.broadcast %13451, %13448 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %13450, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %13445, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %13454 = stablehlo.dot %13453, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_4739 = tensor.dim %13426, %c0 : tensor + %13455 = arith.index_cast %dim_4739 : index to i64 + %from_elements_4740 = tensor.from_elements %13455, %c1_i64 : tensor<2xi64> + %13456 = stablehlo.dynamic_reshape %13426, %from_elements_4740 : (tensor, tensor<2xi64>) -> tensor + %dim_4741 = tensor.dim %13423, %c0 : tensor + %13457 = arith.index_cast %dim_4741 : index to i64 + %from_elements_4742 = tensor.from_elements %13457, %c1_i64 : tensor<2xi64> + %13458 = stablehlo.dynamic_reshape %13423, %from_elements_4742 : (tensor, tensor<2xi64>) -> tensor + %13459 = stablehlo.concatenate %13456, %13458, dim = 1 : (tensor, tensor) -> tensor + %13460 = "stablehlo.gather"(%13076, %13459) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %13461 = shape.shape_of %13454 : tensor -> tensor<2xindex> + %13462 = shape.shape_of %13460 : tensor -> tensor<2xindex> + %13463 = shape.cstr_broadcastable %13461, %13462 : tensor<2xindex>, tensor<2xindex> + %13464 = shape.assuming %13463 -> (tensor) { + %19688 = shape.broadcast %13461, %13462 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %13454, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %13460, %19688, dims = [0, 1] : (tensor, 
tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %13465 = shape.shape_of %13464 : tensor -> tensor<2xindex> + %13466 = stablehlo.dynamic_broadcast_in_dim %13464, %13465, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %13467 = stablehlo.dynamic_broadcast_in_dim %213, %13465, dims = [] : (tensor, tensor<2xindex>) -> tensor + %13468 = stablehlo.multiply %13466, %13467 : tensor + %dim_4743 = tensor.dim %13428, %c0 : tensor + %13469 = arith.index_cast %dim_4743 : index to i64 + %dim_4744 = tensor.dim %13464, %c0 : tensor + %13470 = arith.index_cast %dim_4744 : index to i64 + %13471 = arith.maxsi %13469, %13470 : i64 + %13472 = arith.index_cast %13471 : i64 to index + %from_elements_4745 = tensor.from_elements %13472, %c4096 : tensor<2xindex> + %13473 = stablehlo.dynamic_broadcast_in_dim %13428, %from_elements_4745, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_4746 = tensor.dim %13473, %c0 : tensor + %13474 = arith.index_cast %dim_4746 : index to i64 + %from_elements_4747 = tensor.from_elements %13474, %c4096_i64 : tensor<2xi64> + %13475 = stablehlo.real_dynamic_slice %13468, %c_22, %from_elements_4747, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_4748 = tensor.from_elements %13474, %c4096_i64, %c1_i64 : tensor<3xi64> + %13476 = stablehlo.dynamic_reshape %13473, %from_elements_4748 : (tensor, tensor<3xi64>) -> tensor + %13477 = stablehlo.dynamic_iota %from_elements_4748, dim = 1 : (tensor<3xi64>) -> tensor + %13478 = stablehlo.concatenate %13476, %13477, dim = 2 : (tensor, tensor) -> tensor + %13479 = "stablehlo.scatter"(%13416, %13478, %13475) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %13480 = stablehlo.slice %13036 [7:8, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %13481 = stablehlo.reshape %13480 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %13482 = stablehlo.custom_call @byteir.non_zero(%13481) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_4749 = tensor.dim %13482, %c0 : tensor + %13483 = arith.index_cast %dim_4749 : index to i64 + %from_elements_4750 = tensor.from_elements %13483, %c1_i64 : tensor<2xi64> + %13484 = stablehlo.real_dynamic_slice %13482, %c_22, %from_elements_4750, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4751 = tensor.dim %13484, %c0 : tensor + %13485 = arith.index_cast %dim_4751 : index to i64 + %from_elements_4752 = tensor.from_elements %13485 : tensor<1xi64> + %13486 = stablehlo.dynamic_reshape %13484, %from_elements_4752 : (tensor, tensor<1xi64>) -> tensor + %from_elements_4753 = tensor.from_elements %13483, %c2_i64 : tensor<2xi64> + %13487 = stablehlo.real_dynamic_slice %13482, %c_24, %from_elements_4753, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4754 = tensor.dim %13487, %c0 : tensor + %13488 = arith.index_cast %dim_4754 : index to i64 + %from_elements_4755 = tensor.from_elements %13488 : tensor<1xi64> + %13489 = stablehlo.dynamic_reshape %13487, %from_elements_4755 : (tensor, tensor<1xi64>) -> tensor + %dim_4756 = tensor.dim %13489, %c0 : tensor + %13490 = arith.index_cast %dim_4756 : index to i64 + %from_elements_4757 = tensor.from_elements %13490, %c1_i64 : tensor<2xi64> + %13491 
= stablehlo.dynamic_reshape %13489, %from_elements_4757 : (tensor, tensor<2xi64>) -> tensor + %dim_4758 = tensor.dim %13491, %c0 : tensor + %13492 = arith.index_cast %dim_4758 : index to i64 + %from_elements_4759 = tensor.from_elements %c1_i64, %13492, %c4096_i64 : tensor<3xi64> + %13493 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_4759, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4760 = tensor.dim %13493, %c1 : tensor<1x?x4096xi64> + %13494 = arith.index_cast %dim_4760 : index to i64 + %from_elements_4761 = tensor.from_elements %c1_i64, %13494, %c4096_i64, %c1_i64 : tensor<4xi64> + %13495 = stablehlo.dynamic_reshape %13493, %from_elements_4761 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %13496 = stablehlo.dynamic_broadcast_in_dim %13491, %from_elements_4759, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4762 = tensor.dim %13496, %c1 : tensor<1x?x4096xi64> + %13497 = arith.index_cast %dim_4762 : index to i64 + %from_elements_4763 = tensor.from_elements %c1_i64, %13497, %c4096_i64, %c1_i64 : tensor<4xi64> + %13498 = stablehlo.dynamic_reshape %13496, %from_elements_4763 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %13499 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_4759, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4764 = tensor.dim %13499, %c1 : tensor<1x?x4096xi64> + %13500 = arith.index_cast %dim_4764 : index to i64 + %from_elements_4765 = tensor.from_elements %c1_i64, %13500, %c4096_i64, %c1_i64 : tensor<4xi64> + %13501 = stablehlo.dynamic_reshape %13499, %from_elements_4765 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %13502 = stablehlo.concatenate %13495, %13498, %13501, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %13503 = "stablehlo.gather"(%13047, %13502) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %13504 = shape.shape_of %13503 : tensor<1x?x4096xf32> -> tensor<3xindex> + %13505 = shape.num_elements %13504 : tensor<3xindex> -> index + %13506 = stablehlo.compute_reshape_shape %13505, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %13507 = stablehlo.dynamic_reshape %13503, %13506 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %13508 = stablehlo.dot %13507, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %13509 = stablehlo.logistic %13508 : tensor + %13510 = shape.shape_of %13509 : tensor -> tensor<2xindex> + %13511 = shape.shape_of %13508 : tensor -> tensor<2xindex> + %13512 = shape.cstr_broadcastable %13510, %13511 : tensor<2xindex>, tensor<2xindex> + %13513 = shape.assuming %13512 -> (tensor) { + %19688 = shape.broadcast %13510, %13511 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %13509, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %13508, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %13514 = shape.shape_of %13513 : tensor -> tensor<2xindex> + %13515 = shape.cstr_broadcastable %13514, %13511 : tensor<2xindex>, tensor<2xindex> + %13516 = shape.assuming %13515 -> (tensor) { + %19688 = shape.broadcast %13514, %13511 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + 
%19689 = stablehlo.dynamic_broadcast_in_dim %13513, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %13508, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %13517 = stablehlo.dot %13516, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_4766 = tensor.dim %13489, %c0 : tensor + %13518 = arith.index_cast %dim_4766 : index to i64 + %from_elements_4767 = tensor.from_elements %13518, %c1_i64 : tensor<2xi64> + %13519 = stablehlo.dynamic_reshape %13489, %from_elements_4767 : (tensor, tensor<2xi64>) -> tensor + %dim_4768 = tensor.dim %13486, %c0 : tensor + %13520 = arith.index_cast %dim_4768 : index to i64 + %from_elements_4769 = tensor.from_elements %13520, %c1_i64 : tensor<2xi64> + %13521 = stablehlo.dynamic_reshape %13486, %from_elements_4769 : (tensor, tensor<2xi64>) -> tensor + %13522 = stablehlo.concatenate %13519, %13521, dim = 1 : (tensor, tensor) -> tensor + %13523 = "stablehlo.gather"(%13076, %13522) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %13524 = shape.shape_of %13517 : tensor -> tensor<2xindex> + %13525 = shape.shape_of %13523 : tensor -> tensor<2xindex> + %13526 = shape.cstr_broadcastable %13524, %13525 : tensor<2xindex>, tensor<2xindex> + %13527 = shape.assuming %13526 -> (tensor) { + %19688 = shape.broadcast %13524, %13525 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %13517, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %13523, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %13528 = shape.shape_of %13527 : tensor -> tensor<2xindex> + %13529 = stablehlo.dynamic_broadcast_in_dim %13527, %13528, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %13530 = stablehlo.dynamic_broadcast_in_dim %213, %13528, dims = [] : (tensor, tensor<2xindex>) -> tensor + %13531 = stablehlo.multiply %13529, %13530 : tensor + %dim_4770 = tensor.dim %13491, %c0 : tensor + %13532 = arith.index_cast %dim_4770 : index to i64 + %dim_4771 = tensor.dim %13527, %c0 : tensor + %13533 = arith.index_cast %dim_4771 : index to i64 + %13534 = arith.maxsi %13532, %13533 : i64 + %13535 = arith.index_cast %13534 : i64 to index + %from_elements_4772 = tensor.from_elements %13535, %c4096 : tensor<2xindex> + %13536 = stablehlo.dynamic_broadcast_in_dim %13491, %from_elements_4772, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_4773 = tensor.dim %13536, %c0 : tensor + %13537 = arith.index_cast %dim_4773 : index to i64 + %from_elements_4774 = tensor.from_elements %13537, %c4096_i64 : tensor<2xi64> + %13538 = stablehlo.real_dynamic_slice %13531, %c_22, %from_elements_4774, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_4775 = tensor.from_elements %13537, %c4096_i64, %c1_i64 : tensor<3xi64> + %13539 = stablehlo.dynamic_reshape %13536, %from_elements_4775 : (tensor, tensor<3xi64>) -> tensor + %13540 = stablehlo.dynamic_iota %from_elements_4775, dim = 1 : (tensor<3xi64>) -> tensor + %13541 = stablehlo.concatenate %13539, %13540, dim = 2 : (tensor, tensor) -> tensor + %13542 = "stablehlo.scatter"(%13479, %13541, %13538) <{indices_are_sorted = false, scatter_dimension_numbers = 
#stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %13543 = stablehlo.reshape %13542 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %13544 = stablehlo.add %13009, %13543 : tensor<3x1x4096xf32> + %13545 = stablehlo.broadcast_in_dim %13544, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %13546 = stablehlo.power %13545, %15 : tensor<3x1x4096xf32> + %13547 = stablehlo.reduce(%13546 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %13548 = stablehlo.reshape %13547 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %13549 = stablehlo.broadcast_in_dim %13548, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %13550 = stablehlo.divide %13549, %21 : tensor<3x1x1xf32> + %13551 = stablehlo.broadcast_in_dim %13550, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %13552 = stablehlo.add %13551, %25 : tensor<3x1x1xf32> + %13553 = stablehlo.rsqrt %13552 : tensor<3x1x1xf32> + %13554 = stablehlo.broadcast_in_dim %13553, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %13555 = stablehlo.multiply %13545, %13554 : tensor<3x1x4096xf32> + %13556 = stablehlo.broadcast_in_dim %13555, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %13557 = stablehlo.multiply %13556, %31 : tensor<3x1x4096xf32> + %13558 = stablehlo.reshape %13557 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %13559 = stablehlo.dot %13558, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %13560 = stablehlo.reshape %13559 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %13561 = stablehlo.dot %13558, %37 : (tensor<3x4096xf32>, tensor<4096x1024xf32>) -> tensor<3x1024xf32> + %13562 = stablehlo.reshape %13561 : (tensor<3x1024xf32>) -> tensor<3x1x1024xf32> + %13563 = stablehlo.reshape %13560 : (tensor<3x1x4096xf32>) -> tensor<3x1x32x128xf32> + %13564 = stablehlo.transpose %13563, dims = [0, 2, 1, 3] : (tensor<3x1x32x128xf32>) -> tensor<3x32x1x128xf32> + %13565 = stablehlo.reshape %13562 : (tensor<3x1x1024xf32>) -> tensor<3x1x8x128xf32> + %13566 = stablehlo.transpose %13565, dims = [0, 2, 1, 3] : (tensor<3x1x8x128xf32>) -> tensor<3x8x1x128xf32> + %13567 = stablehlo.slice %arg44 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %13568 = stablehlo.slice %arg45 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %13569 = "stablehlo.gather"(%13567, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %13570 = stablehlo.reshape %13569 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %13571 = "stablehlo.gather"(%13568, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %13572 = stablehlo.reshape %13571 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %13573 = stablehlo.broadcast_in_dim %13564, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %13574 = stablehlo.broadcast_in_dim %13570, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %13575 = stablehlo.multiply %13573, %13574 : tensor<3x32x1x128xf32> + %13576 = stablehlo.slice %13564 [0:3, 0:32, 0:1, 0:64] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + 
%13577 = stablehlo.slice %13564 [0:3, 0:32, 0:1, 64:128] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %13578 = stablehlo.negate %13577 : tensor<3x32x1x64xf32> + %13579 = stablehlo.concatenate %13578, %13576, dim = 3 : (tensor<3x32x1x64xf32>, tensor<3x32x1x64xf32>) -> tensor<3x32x1x128xf32> + %13580 = stablehlo.broadcast_in_dim %13579, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %13581 = stablehlo.broadcast_in_dim %13572, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %13582 = stablehlo.multiply %13580, %13581 : tensor<3x32x1x128xf32> + %13583 = stablehlo.add %13575, %13582 : tensor<3x32x1x128xf32> + %13584 = stablehlo.broadcast_in_dim %13566, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %13585 = stablehlo.broadcast_in_dim %13570, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %13586 = stablehlo.multiply %13584, %13585 : tensor<3x8x1x128xf32> + %13587 = stablehlo.slice %13566 [0:3, 0:8, 0:1, 0:64] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %13588 = stablehlo.slice %13566 [0:3, 0:8, 0:1, 64:128] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %13589 = stablehlo.negate %13588 : tensor<3x8x1x64xf32> + %13590 = stablehlo.concatenate %13589, %13587, dim = 3 : (tensor<3x8x1x64xf32>, tensor<3x8x1x64xf32>) -> tensor<3x8x1x128xf32> + %13591 = stablehlo.broadcast_in_dim %13590, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %13592 = stablehlo.broadcast_in_dim %13572, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %13593 = stablehlo.multiply %13591, %13592 : tensor<3x8x1x128xf32> + %13594 = stablehlo.add %13586, %13593 : tensor<3x8x1x128xf32> + %13595 = stablehlo.concatenate %arg109, %13594, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %13596 = stablehlo.concatenate %arg110, %13566, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %13597 = stablehlo.reshape %13595 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %13598 = stablehlo.broadcast_in_dim %13597, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %13599 = stablehlo.reshape %13598 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %13600 = stablehlo.reshape %13596 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %13601 = stablehlo.broadcast_in_dim %13600, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %13602 = stablehlo.reshape %13601 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %13603 = stablehlo.transpose %13599, dims = [0, 1, 3, 2] : (tensor<3x32x8x128xf32>) -> tensor<3x32x128x8xf32> + %13604 = stablehlo.reshape %13583 : (tensor<3x32x1x128xf32>) -> tensor<96x1x128xf32> + %13605 = stablehlo.reshape %13603 : (tensor<3x32x128x8xf32>) -> tensor<96x128x8xf32> + %13606 = stablehlo.broadcast_in_dim %13605, dims = [0, 1, 2] : (tensor<96x128x8xf32>) -> tensor<96x128x8xf32> + %13607 = stablehlo.dot_general %13604, %13606, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x128xf32>, tensor<96x128x8xf32>) -> tensor<96x1x8xf32> + %13608 = stablehlo.reshape %13607 : (tensor<96x1x8xf32>) -> tensor<3x32x1x8xf32> + %13609 = stablehlo.broadcast_in_dim %13608, dims = [0, 1, 2, 3] : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %13610 = stablehlo.divide %13609, %89 : tensor<3x32x1x8xf32> + %13611 = stablehlo.custom_call @byteir.softmax(%13610) {byteir_attrs = {axis = 3 : i64}} : 
(tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %13612 = stablehlo.reshape %13611 : (tensor<3x32x1x8xf32>) -> tensor<96x1x8xf32> + %13613 = stablehlo.reshape %13602 : (tensor<3x32x8x128xf32>) -> tensor<96x8x128xf32> + %13614 = stablehlo.broadcast_in_dim %13613, dims = [0, 1, 2] : (tensor<96x8x128xf32>) -> tensor<96x8x128xf32> + %13615 = stablehlo.dot_general %13612, %13614, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x8xf32>, tensor<96x8x128xf32>) -> tensor<96x1x128xf32> + %13616 = stablehlo.reshape %13615 : (tensor<96x1x128xf32>) -> tensor<3x32x1x128xf32> + %13617 = stablehlo.transpose %13616, dims = [0, 2, 1, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x1x32x128xf32> + %13618 = stablehlo.reshape %13617 : (tensor<3x1x32x128xf32>) -> tensor<3x1x4096xf32> + %13619 = stablehlo.reshape %13618 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %13620 = stablehlo.dot %13619, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %13621 = stablehlo.reshape %13620 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %13622 = stablehlo.add %13544, %13621 : tensor<3x1x4096xf32> + %13623 = stablehlo.broadcast_in_dim %13622, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %13624 = stablehlo.power %13623, %15 : tensor<3x1x4096xf32> + %13625 = stablehlo.reduce(%13624 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %13626 = stablehlo.reshape %13625 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %13627 = stablehlo.broadcast_in_dim %13626, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %13628 = stablehlo.divide %13627, %21 : tensor<3x1x1xf32> + %13629 = stablehlo.broadcast_in_dim %13628, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %13630 = stablehlo.add %13629, %25 : tensor<3x1x1xf32> + %13631 = stablehlo.rsqrt %13630 : tensor<3x1x1xf32> + %13632 = stablehlo.broadcast_in_dim %13631, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %13633 = stablehlo.multiply %13623, %13632 : tensor<3x1x4096xf32> + %13634 = stablehlo.broadcast_in_dim %13633, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %13635 = stablehlo.multiply %13634, %31 : tensor<3x1x4096xf32> + %13636 = stablehlo.reshape %13635 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %13637 = stablehlo.dot %13636, %117 : (tensor<3x4096xf32>, tensor<4096x8xf32>) -> tensor<3x8xf32> + %13638 = stablehlo.custom_call @byteir.softmax(%13637) {byteir_attrs = {axis = 1 : i64}} : (tensor<3x8xf32>) -> tensor<3x8xf32> + %13639:2 = stablehlo.custom_call @byteir.top_k(%13638) {byteir_attrs = {axis = [1], k = 2 : i64, sorted = true}} : (tensor<3x8xf32>) -> (tensor<3x2xf32>, tensor<3x2xi64>) + %13640 = stablehlo.reduce(%13639#0 init: %cst_20) applies stablehlo.add across dimensions = [1] : (tensor<3x2xf32>, tensor) -> tensor<3xf32> + %13641 = stablehlo.reshape %13640 : (tensor<3xf32>) -> tensor<3x1xf32> + %13642 = stablehlo.broadcast_in_dim %13639#0, dims = [0, 1] : (tensor<3x2xf32>) -> tensor<3x2xf32> + %13643 = stablehlo.broadcast_in_dim %13641, dims = [0, 1] : (tensor<3x1xf32>) -> tensor<3x2xf32> + %13644 = stablehlo.divide %13642, %13643 : tensor<3x2xf32> + %13645 = stablehlo.reshape %13639#1 : (tensor<3x2xi64>) -> tensor<3x2x1xi64> + %13646 = stablehlo.broadcast_in_dim %13645, dims = [0, 1, 2] : (tensor<3x2x1xi64>) -> tensor<3x2x8xi64> + %13647 = stablehlo.compare EQ, %13646, %137, SIGNED : (tensor<3x2x8xi64>, tensor<3x2x8xi64>) -> tensor<3x2x8xi1> + %13648 = stablehlo.convert 
%13647 : (tensor<3x2x8xi1>) -> tensor<3x2x8xi64> + %13649 = stablehlo.transpose %13648, dims = [2, 1, 0] : (tensor<3x2x8xi64>) -> tensor<8x2x3xi64> + %13650 = stablehlo.slice %13649 [0:1, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %13651 = stablehlo.reshape %13650 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %13652 = stablehlo.custom_call @byteir.non_zero(%13651) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_4776 = tensor.dim %13652, %c0 : tensor + %13653 = arith.index_cast %dim_4776 : index to i64 + %from_elements_4777 = tensor.from_elements %13653, %c1_i64 : tensor<2xi64> + %13654 = stablehlo.real_dynamic_slice %13652, %c_22, %from_elements_4777, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4778 = tensor.dim %13654, %c0 : tensor + %13655 = arith.index_cast %dim_4778 : index to i64 + %from_elements_4779 = tensor.from_elements %13655 : tensor<1xi64> + %13656 = stablehlo.dynamic_reshape %13654, %from_elements_4779 : (tensor, tensor<1xi64>) -> tensor + %from_elements_4780 = tensor.from_elements %13653, %c2_i64 : tensor<2xi64> + %13657 = stablehlo.real_dynamic_slice %13652, %c_24, %from_elements_4780, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4781 = tensor.dim %13657, %c0 : tensor + %13658 = arith.index_cast %dim_4781 : index to i64 + %from_elements_4782 = tensor.from_elements %13658 : tensor<1xi64> + %13659 = stablehlo.dynamic_reshape %13657, %from_elements_4782 : (tensor, tensor<1xi64>) -> tensor + %13660 = stablehlo.reshape %13636 : (tensor<3x4096xf32>) -> tensor<1x3x4096xf32> + %dim_4783 = tensor.dim %13659, %c0 : tensor + %13661 = arith.index_cast %dim_4783 : index to i64 + %from_elements_4784 = tensor.from_elements %13661, %c1_i64 : tensor<2xi64> + %13662 = stablehlo.dynamic_reshape %13659, %from_elements_4784 : (tensor, tensor<2xi64>) -> tensor + %dim_4785 = tensor.dim %13662, %c0 : tensor + %13663 = arith.index_cast %dim_4785 : index to i64 + %from_elements_4786 = tensor.from_elements %c1_i64, %13663, %c4096_i64 : tensor<3xi64> + %13664 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_4786, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4787 = tensor.dim %13664, %c1 : tensor<1x?x4096xi64> + %13665 = arith.index_cast %dim_4787 : index to i64 + %from_elements_4788 = tensor.from_elements %c1_i64, %13665, %c4096_i64, %c1_i64 : tensor<4xi64> + %13666 = stablehlo.dynamic_reshape %13664, %from_elements_4788 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %13667 = stablehlo.dynamic_broadcast_in_dim %13662, %from_elements_4786, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4789 = tensor.dim %13667, %c1 : tensor<1x?x4096xi64> + %13668 = arith.index_cast %dim_4789 : index to i64 + %from_elements_4790 = tensor.from_elements %c1_i64, %13668, %c4096_i64, %c1_i64 : tensor<4xi64> + %13669 = stablehlo.dynamic_reshape %13667, %from_elements_4790 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %13670 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_4786, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4791 = tensor.dim %13670, %c1 : tensor<1x?x4096xi64> + %13671 = arith.index_cast %dim_4791 : index to i64 + %from_elements_4792 = tensor.from_elements %c1_i64, %13671, %c4096_i64, %c1_i64 : tensor<4xi64> + %13672 = stablehlo.dynamic_reshape %13670, %from_elements_4792 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %13673 = stablehlo.concatenate 
%13666, %13669, %13672, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %13674 = "stablehlo.gather"(%13660, %13673) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %13675 = shape.shape_of %13674 : tensor<1x?x4096xf32> -> tensor<3xindex> + %13676 = shape.num_elements %13675 : tensor<3xindex> -> index + %13677 = stablehlo.compute_reshape_shape %13676, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %13678 = stablehlo.dynamic_reshape %13674, %13677 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %13679 = stablehlo.dot %13678, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %13680 = stablehlo.logistic %13679 : tensor + %13681 = shape.shape_of %13680 : tensor -> tensor<2xindex> + %13682 = shape.shape_of %13679 : tensor -> tensor<2xindex> + %13683 = shape.cstr_broadcastable %13681, %13682 : tensor<2xindex>, tensor<2xindex> + %13684 = shape.assuming %13683 -> (tensor) { + %19688 = shape.broadcast %13681, %13682 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %13680, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %13679, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %13685 = shape.shape_of %13684 : tensor -> tensor<2xindex> + %13686 = shape.cstr_broadcastable %13685, %13682 : tensor<2xindex>, tensor<2xindex> + %13687 = shape.assuming %13686 -> (tensor) { + %19688 = shape.broadcast %13685, %13682 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %13684, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %13679, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %13688 = stablehlo.dot %13687, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %13689 = stablehlo.reshape %13644 : (tensor<3x2xf32>) -> tensor<3x2x1xf32> + %dim_4793 = tensor.dim %13659, %c0 : tensor + %13690 = arith.index_cast %dim_4793 : index to i64 + %from_elements_4794 = tensor.from_elements %13690, %c1_i64 : tensor<2xi64> + %13691 = stablehlo.dynamic_reshape %13659, %from_elements_4794 : (tensor, tensor<2xi64>) -> tensor + %dim_4795 = tensor.dim %13656, %c0 : tensor + %13692 = arith.index_cast %dim_4795 : index to i64 + %from_elements_4796 = tensor.from_elements %13692, %c1_i64 : tensor<2xi64> + %13693 = stablehlo.dynamic_reshape %13656, %from_elements_4796 : (tensor, tensor<2xi64>) -> tensor + %13694 = stablehlo.concatenate %13691, %13693, dim = 1 : (tensor, tensor) -> tensor + %13695 = "stablehlo.gather"(%13689, %13694) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %13696 = shape.shape_of %13688 : tensor -> tensor<2xindex> + %13697 = shape.shape_of %13695 : tensor -> tensor<2xindex> + %13698 = shape.cstr_broadcastable %13696, %13697 : tensor<2xindex>, tensor<2xindex> + %13699 = shape.assuming %13698 -> (tensor) { + %19688 = shape.broadcast %13696, %13697 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %13688, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 
= stablehlo.dynamic_broadcast_in_dim %13695, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %13700 = shape.shape_of %13699 : tensor -> tensor<2xindex> + %13701 = stablehlo.dynamic_broadcast_in_dim %13699, %13700, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %13702 = stablehlo.dynamic_broadcast_in_dim %213, %13700, dims = [] : (tensor, tensor<2xindex>) -> tensor + %13703 = stablehlo.multiply %13701, %13702 : tensor + %dim_4797 = tensor.dim %13662, %c0 : tensor + %13704 = arith.index_cast %dim_4797 : index to i64 + %dim_4798 = tensor.dim %13699, %c0 : tensor + %13705 = arith.index_cast %dim_4798 : index to i64 + %13706 = arith.maxsi %13704, %13705 : i64 + %13707 = arith.index_cast %13706 : i64 to index + %from_elements_4799 = tensor.from_elements %13707, %c4096 : tensor<2xindex> + %13708 = stablehlo.dynamic_broadcast_in_dim %13662, %from_elements_4799, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_4800 = tensor.dim %13708, %c0 : tensor + %13709 = arith.index_cast %dim_4800 : index to i64 + %from_elements_4801 = tensor.from_elements %13709, %c4096_i64 : tensor<2xi64> + %13710 = stablehlo.real_dynamic_slice %13703, %c_22, %from_elements_4801, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_4802 = tensor.from_elements %13709, %c4096_i64, %c1_i64 : tensor<3xi64> + %13711 = stablehlo.dynamic_reshape %13708, %from_elements_4802 : (tensor, tensor<3xi64>) -> tensor + %13712 = stablehlo.dynamic_iota %from_elements_4802, dim = 1 : (tensor<3xi64>) -> tensor + %13713 = stablehlo.concatenate %13711, %13712, dim = 2 : (tensor, tensor) -> tensor + %13714 = "stablehlo.scatter"(%cst_2, %13713, %13710) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %13715 = stablehlo.slice %13649 [1:2, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %13716 = stablehlo.reshape %13715 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %13717 = stablehlo.custom_call @byteir.non_zero(%13716) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_4803 = tensor.dim %13717, %c0 : tensor + %13718 = arith.index_cast %dim_4803 : index to i64 + %from_elements_4804 = tensor.from_elements %13718, %c1_i64 : tensor<2xi64> + %13719 = stablehlo.real_dynamic_slice %13717, %c_22, %from_elements_4804, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4805 = tensor.dim %13719, %c0 : tensor + %13720 = arith.index_cast %dim_4805 : index to i64 + %from_elements_4806 = tensor.from_elements %13720 : tensor<1xi64> + %13721 = stablehlo.dynamic_reshape %13719, %from_elements_4806 : (tensor, tensor<1xi64>) -> tensor + %from_elements_4807 = tensor.from_elements %13718, %c2_i64 : tensor<2xi64> + %13722 = stablehlo.real_dynamic_slice %13717, %c_24, %from_elements_4807, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4808 = tensor.dim %13722, %c0 : tensor + %13723 = arith.index_cast %dim_4808 : index to i64 + %from_elements_4809 = tensor.from_elements %13723 : tensor<1xi64> + %13724 = stablehlo.dynamic_reshape %13722, %from_elements_4809 : (tensor, tensor<1xi64>) -> tensor + %dim_4810 = tensor.dim %13724, %c0 : tensor + %13725 = arith.index_cast %dim_4810 : index to i64 + 
%from_elements_4811 = tensor.from_elements %13725, %c1_i64 : tensor<2xi64> + %13726 = stablehlo.dynamic_reshape %13724, %from_elements_4811 : (tensor, tensor<2xi64>) -> tensor + %dim_4812 = tensor.dim %13726, %c0 : tensor + %13727 = arith.index_cast %dim_4812 : index to i64 + %from_elements_4813 = tensor.from_elements %c1_i64, %13727, %c4096_i64 : tensor<3xi64> + %13728 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_4813, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4814 = tensor.dim %13728, %c1 : tensor<1x?x4096xi64> + %13729 = arith.index_cast %dim_4814 : index to i64 + %from_elements_4815 = tensor.from_elements %c1_i64, %13729, %c4096_i64, %c1_i64 : tensor<4xi64> + %13730 = stablehlo.dynamic_reshape %13728, %from_elements_4815 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %13731 = stablehlo.dynamic_broadcast_in_dim %13726, %from_elements_4813, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4816 = tensor.dim %13731, %c1 : tensor<1x?x4096xi64> + %13732 = arith.index_cast %dim_4816 : index to i64 + %from_elements_4817 = tensor.from_elements %c1_i64, %13732, %c4096_i64, %c1_i64 : tensor<4xi64> + %13733 = stablehlo.dynamic_reshape %13731, %from_elements_4817 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %13734 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_4813, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4818 = tensor.dim %13734, %c1 : tensor<1x?x4096xi64> + %13735 = arith.index_cast %dim_4818 : index to i64 + %from_elements_4819 = tensor.from_elements %c1_i64, %13735, %c4096_i64, %c1_i64 : tensor<4xi64> + %13736 = stablehlo.dynamic_reshape %13734, %from_elements_4819 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %13737 = stablehlo.concatenate %13730, %13733, %13736, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %13738 = "stablehlo.gather"(%13660, %13737) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %13739 = shape.shape_of %13738 : tensor<1x?x4096xf32> -> tensor<3xindex> + %13740 = shape.num_elements %13739 : tensor<3xindex> -> index + %13741 = stablehlo.compute_reshape_shape %13740, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %13742 = stablehlo.dynamic_reshape %13738, %13741 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %13743 = stablehlo.dot %13742, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %13744 = stablehlo.logistic %13743 : tensor + %13745 = shape.shape_of %13744 : tensor -> tensor<2xindex> + %13746 = shape.shape_of %13743 : tensor -> tensor<2xindex> + %13747 = shape.cstr_broadcastable %13745, %13746 : tensor<2xindex>, tensor<2xindex> + %13748 = shape.assuming %13747 -> (tensor) { + %19688 = shape.broadcast %13745, %13746 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %13744, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %13743, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %13749 = shape.shape_of %13748 : tensor -> tensor<2xindex> + %13750 = shape.cstr_broadcastable %13749, %13746 : tensor<2xindex>, tensor<2xindex> + %13751 = shape.assuming %13750 -> (tensor) { + %19688 = 
shape.broadcast %13749, %13746 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %13748, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %13743, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %13752 = stablehlo.dot %13751, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_4820 = tensor.dim %13724, %c0 : tensor + %13753 = arith.index_cast %dim_4820 : index to i64 + %from_elements_4821 = tensor.from_elements %13753, %c1_i64 : tensor<2xi64> + %13754 = stablehlo.dynamic_reshape %13724, %from_elements_4821 : (tensor, tensor<2xi64>) -> tensor + %dim_4822 = tensor.dim %13721, %c0 : tensor + %13755 = arith.index_cast %dim_4822 : index to i64 + %from_elements_4823 = tensor.from_elements %13755, %c1_i64 : tensor<2xi64> + %13756 = stablehlo.dynamic_reshape %13721, %from_elements_4823 : (tensor, tensor<2xi64>) -> tensor + %13757 = stablehlo.concatenate %13754, %13756, dim = 1 : (tensor, tensor) -> tensor + %13758 = "stablehlo.gather"(%13689, %13757) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %13759 = shape.shape_of %13752 : tensor -> tensor<2xindex> + %13760 = shape.shape_of %13758 : tensor -> tensor<2xindex> + %13761 = shape.cstr_broadcastable %13759, %13760 : tensor<2xindex>, tensor<2xindex> + %13762 = shape.assuming %13761 -> (tensor) { + %19688 = shape.broadcast %13759, %13760 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %13752, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %13758, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %13763 = shape.shape_of %13762 : tensor -> tensor<2xindex> + %13764 = stablehlo.dynamic_broadcast_in_dim %13762, %13763, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %13765 = stablehlo.dynamic_broadcast_in_dim %213, %13763, dims = [] : (tensor, tensor<2xindex>) -> tensor + %13766 = stablehlo.multiply %13764, %13765 : tensor + %dim_4824 = tensor.dim %13726, %c0 : tensor + %13767 = arith.index_cast %dim_4824 : index to i64 + %dim_4825 = tensor.dim %13762, %c0 : tensor + %13768 = arith.index_cast %dim_4825 : index to i64 + %13769 = arith.maxsi %13767, %13768 : i64 + %13770 = arith.index_cast %13769 : i64 to index + %from_elements_4826 = tensor.from_elements %13770, %c4096 : tensor<2xindex> + %13771 = stablehlo.dynamic_broadcast_in_dim %13726, %from_elements_4826, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_4827 = tensor.dim %13771, %c0 : tensor + %13772 = arith.index_cast %dim_4827 : index to i64 + %from_elements_4828 = tensor.from_elements %13772, %c4096_i64 : tensor<2xi64> + %13773 = stablehlo.real_dynamic_slice %13766, %c_22, %from_elements_4828, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_4829 = tensor.from_elements %13772, %c4096_i64, %c1_i64 : tensor<3xi64> + %13774 = stablehlo.dynamic_reshape %13771, %from_elements_4829 : (tensor, tensor<3xi64>) -> tensor + %13775 = stablehlo.dynamic_iota %from_elements_4829, dim = 1 : (tensor<3xi64>) -> tensor + %13776 = stablehlo.concatenate %13774, %13775, dim = 2 : (tensor, tensor) -> tensor + %13777 = "stablehlo.scatter"(%13714, 
%13776, %13773) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %13778 = stablehlo.slice %13649 [2:3, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %13779 = stablehlo.reshape %13778 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %13780 = stablehlo.custom_call @byteir.non_zero(%13779) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_4830 = tensor.dim %13780, %c0 : tensor + %13781 = arith.index_cast %dim_4830 : index to i64 + %from_elements_4831 = tensor.from_elements %13781, %c1_i64 : tensor<2xi64> + %13782 = stablehlo.real_dynamic_slice %13780, %c_22, %from_elements_4831, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4832 = tensor.dim %13782, %c0 : tensor + %13783 = arith.index_cast %dim_4832 : index to i64 + %from_elements_4833 = tensor.from_elements %13783 : tensor<1xi64> + %13784 = stablehlo.dynamic_reshape %13782, %from_elements_4833 : (tensor, tensor<1xi64>) -> tensor + %from_elements_4834 = tensor.from_elements %13781, %c2_i64 : tensor<2xi64> + %13785 = stablehlo.real_dynamic_slice %13780, %c_24, %from_elements_4834, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4835 = tensor.dim %13785, %c0 : tensor + %13786 = arith.index_cast %dim_4835 : index to i64 + %from_elements_4836 = tensor.from_elements %13786 : tensor<1xi64> + %13787 = stablehlo.dynamic_reshape %13785, %from_elements_4836 : (tensor, tensor<1xi64>) -> tensor + %dim_4837 = tensor.dim %13787, %c0 : tensor + %13788 = arith.index_cast %dim_4837 : index to i64 + %from_elements_4838 = tensor.from_elements %13788, %c1_i64 : tensor<2xi64> + %13789 = stablehlo.dynamic_reshape %13787, %from_elements_4838 : (tensor, tensor<2xi64>) -> tensor + %dim_4839 = tensor.dim %13789, %c0 : tensor + %13790 = arith.index_cast %dim_4839 : index to i64 + %from_elements_4840 = tensor.from_elements %c1_i64, %13790, %c4096_i64 : tensor<3xi64> + %13791 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_4840, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4841 = tensor.dim %13791, %c1 : tensor<1x?x4096xi64> + %13792 = arith.index_cast %dim_4841 : index to i64 + %from_elements_4842 = tensor.from_elements %c1_i64, %13792, %c4096_i64, %c1_i64 : tensor<4xi64> + %13793 = stablehlo.dynamic_reshape %13791, %from_elements_4842 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %13794 = stablehlo.dynamic_broadcast_in_dim %13789, %from_elements_4840, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4843 = tensor.dim %13794, %c1 : tensor<1x?x4096xi64> + %13795 = arith.index_cast %dim_4843 : index to i64 + %from_elements_4844 = tensor.from_elements %c1_i64, %13795, %c4096_i64, %c1_i64 : tensor<4xi64> + %13796 = stablehlo.dynamic_reshape %13794, %from_elements_4844 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %13797 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_4840, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4845 = tensor.dim %13797, %c1 : tensor<1x?x4096xi64> + %13798 = arith.index_cast %dim_4845 : index to i64 + %from_elements_4846 = tensor.from_elements %c1_i64, %13798, %c4096_i64, %c1_i64 : tensor<4xi64> + %13799 = stablehlo.dynamic_reshape %13797, %from_elements_4846 : 
(tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %13800 = stablehlo.concatenate %13793, %13796, %13799, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %13801 = "stablehlo.gather"(%13660, %13800) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %13802 = shape.shape_of %13801 : tensor<1x?x4096xf32> -> tensor<3xindex> + %13803 = shape.num_elements %13802 : tensor<3xindex> -> index + %13804 = stablehlo.compute_reshape_shape %13803, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %13805 = stablehlo.dynamic_reshape %13801, %13804 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %13806 = stablehlo.dot %13805, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %13807 = stablehlo.logistic %13806 : tensor + %13808 = shape.shape_of %13807 : tensor -> tensor<2xindex> + %13809 = shape.shape_of %13806 : tensor -> tensor<2xindex> + %13810 = shape.cstr_broadcastable %13808, %13809 : tensor<2xindex>, tensor<2xindex> + %13811 = shape.assuming %13810 -> (tensor) { + %19688 = shape.broadcast %13808, %13809 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %13807, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %13806, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %13812 = shape.shape_of %13811 : tensor -> tensor<2xindex> + %13813 = shape.cstr_broadcastable %13812, %13809 : tensor<2xindex>, tensor<2xindex> + %13814 = shape.assuming %13813 -> (tensor) { + %19688 = shape.broadcast %13812, %13809 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %13811, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %13806, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %13815 = stablehlo.dot %13814, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_4847 = tensor.dim %13787, %c0 : tensor + %13816 = arith.index_cast %dim_4847 : index to i64 + %from_elements_4848 = tensor.from_elements %13816, %c1_i64 : tensor<2xi64> + %13817 = stablehlo.dynamic_reshape %13787, %from_elements_4848 : (tensor, tensor<2xi64>) -> tensor + %dim_4849 = tensor.dim %13784, %c0 : tensor + %13818 = arith.index_cast %dim_4849 : index to i64 + %from_elements_4850 = tensor.from_elements %13818, %c1_i64 : tensor<2xi64> + %13819 = stablehlo.dynamic_reshape %13784, %from_elements_4850 : (tensor, tensor<2xi64>) -> tensor + %13820 = stablehlo.concatenate %13817, %13819, dim = 1 : (tensor, tensor) -> tensor + %13821 = "stablehlo.gather"(%13689, %13820) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %13822 = shape.shape_of %13815 : tensor -> tensor<2xindex> + %13823 = shape.shape_of %13821 : tensor -> tensor<2xindex> + %13824 = shape.cstr_broadcastable %13822, %13823 : tensor<2xindex>, tensor<2xindex> + %13825 = shape.assuming %13824 -> (tensor) { + %19688 = shape.broadcast %13822, %13823 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %13815, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) 
-> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %13821, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %13826 = shape.shape_of %13825 : tensor -> tensor<2xindex> + %13827 = stablehlo.dynamic_broadcast_in_dim %13825, %13826, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %13828 = stablehlo.dynamic_broadcast_in_dim %213, %13826, dims = [] : (tensor, tensor<2xindex>) -> tensor + %13829 = stablehlo.multiply %13827, %13828 : tensor + %dim_4851 = tensor.dim %13789, %c0 : tensor + %13830 = arith.index_cast %dim_4851 : index to i64 + %dim_4852 = tensor.dim %13825, %c0 : tensor + %13831 = arith.index_cast %dim_4852 : index to i64 + %13832 = arith.maxsi %13830, %13831 : i64 + %13833 = arith.index_cast %13832 : i64 to index + %from_elements_4853 = tensor.from_elements %13833, %c4096 : tensor<2xindex> + %13834 = stablehlo.dynamic_broadcast_in_dim %13789, %from_elements_4853, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_4854 = tensor.dim %13834, %c0 : tensor + %13835 = arith.index_cast %dim_4854 : index to i64 + %from_elements_4855 = tensor.from_elements %13835, %c4096_i64 : tensor<2xi64> + %13836 = stablehlo.real_dynamic_slice %13829, %c_22, %from_elements_4855, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_4856 = tensor.from_elements %13835, %c4096_i64, %c1_i64 : tensor<3xi64> + %13837 = stablehlo.dynamic_reshape %13834, %from_elements_4856 : (tensor, tensor<3xi64>) -> tensor + %13838 = stablehlo.dynamic_iota %from_elements_4856, dim = 1 : (tensor<3xi64>) -> tensor + %13839 = stablehlo.concatenate %13837, %13838, dim = 2 : (tensor, tensor) -> tensor + %13840 = "stablehlo.scatter"(%13777, %13839, %13836) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %13841 = stablehlo.slice %13649 [3:4, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %13842 = stablehlo.reshape %13841 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %13843 = stablehlo.custom_call @byteir.non_zero(%13842) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_4857 = tensor.dim %13843, %c0 : tensor + %13844 = arith.index_cast %dim_4857 : index to i64 + %from_elements_4858 = tensor.from_elements %13844, %c1_i64 : tensor<2xi64> + %13845 = stablehlo.real_dynamic_slice %13843, %c_22, %from_elements_4858, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4859 = tensor.dim %13845, %c0 : tensor + %13846 = arith.index_cast %dim_4859 : index to i64 + %from_elements_4860 = tensor.from_elements %13846 : tensor<1xi64> + %13847 = stablehlo.dynamic_reshape %13845, %from_elements_4860 : (tensor, tensor<1xi64>) -> tensor + %from_elements_4861 = tensor.from_elements %13844, %c2_i64 : tensor<2xi64> + %13848 = stablehlo.real_dynamic_slice %13843, %c_24, %from_elements_4861, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4862 = tensor.dim %13848, %c0 : tensor + %13849 = arith.index_cast %dim_4862 : index to i64 + %from_elements_4863 = tensor.from_elements %13849 : tensor<1xi64> + %13850 = stablehlo.dynamic_reshape %13848, %from_elements_4863 : (tensor, tensor<1xi64>) -> tensor + %dim_4864 = tensor.dim %13850, %c0 : tensor + %13851 = arith.index_cast %dim_4864 : 
index to i64 + %from_elements_4865 = tensor.from_elements %13851, %c1_i64 : tensor<2xi64> + %13852 = stablehlo.dynamic_reshape %13850, %from_elements_4865 : (tensor, tensor<2xi64>) -> tensor + %dim_4866 = tensor.dim %13852, %c0 : tensor + %13853 = arith.index_cast %dim_4866 : index to i64 + %from_elements_4867 = tensor.from_elements %c1_i64, %13853, %c4096_i64 : tensor<3xi64> + %13854 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_4867, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4868 = tensor.dim %13854, %c1 : tensor<1x?x4096xi64> + %13855 = arith.index_cast %dim_4868 : index to i64 + %from_elements_4869 = tensor.from_elements %c1_i64, %13855, %c4096_i64, %c1_i64 : tensor<4xi64> + %13856 = stablehlo.dynamic_reshape %13854, %from_elements_4869 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %13857 = stablehlo.dynamic_broadcast_in_dim %13852, %from_elements_4867, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4870 = tensor.dim %13857, %c1 : tensor<1x?x4096xi64> + %13858 = arith.index_cast %dim_4870 : index to i64 + %from_elements_4871 = tensor.from_elements %c1_i64, %13858, %c4096_i64, %c1_i64 : tensor<4xi64> + %13859 = stablehlo.dynamic_reshape %13857, %from_elements_4871 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %13860 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_4867, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4872 = tensor.dim %13860, %c1 : tensor<1x?x4096xi64> + %13861 = arith.index_cast %dim_4872 : index to i64 + %from_elements_4873 = tensor.from_elements %c1_i64, %13861, %c4096_i64, %c1_i64 : tensor<4xi64> + %13862 = stablehlo.dynamic_reshape %13860, %from_elements_4873 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %13863 = stablehlo.concatenate %13856, %13859, %13862, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %13864 = "stablehlo.gather"(%13660, %13863) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %13865 = shape.shape_of %13864 : tensor<1x?x4096xf32> -> tensor<3xindex> + %13866 = shape.num_elements %13865 : tensor<3xindex> -> index + %13867 = stablehlo.compute_reshape_shape %13866, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %13868 = stablehlo.dynamic_reshape %13864, %13867 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %13869 = stablehlo.dot %13868, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %13870 = stablehlo.logistic %13869 : tensor + %13871 = shape.shape_of %13870 : tensor -> tensor<2xindex> + %13872 = shape.shape_of %13869 : tensor -> tensor<2xindex> + %13873 = shape.cstr_broadcastable %13871, %13872 : tensor<2xindex>, tensor<2xindex> + %13874 = shape.assuming %13873 -> (tensor) { + %19688 = shape.broadcast %13871, %13872 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %13870, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %13869, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %13875 = shape.shape_of %13874 : tensor -> tensor<2xindex> + %13876 = shape.cstr_broadcastable %13875, %13872 : tensor<2xindex>, tensor<2xindex> + %13877 = shape.assuming %13876 -> (tensor) { 
+ %19688 = shape.broadcast %13875, %13872 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %13874, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %13869, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %13878 = stablehlo.dot %13877, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_4874 = tensor.dim %13850, %c0 : tensor + %13879 = arith.index_cast %dim_4874 : index to i64 + %from_elements_4875 = tensor.from_elements %13879, %c1_i64 : tensor<2xi64> + %13880 = stablehlo.dynamic_reshape %13850, %from_elements_4875 : (tensor, tensor<2xi64>) -> tensor + %dim_4876 = tensor.dim %13847, %c0 : tensor + %13881 = arith.index_cast %dim_4876 : index to i64 + %from_elements_4877 = tensor.from_elements %13881, %c1_i64 : tensor<2xi64> + %13882 = stablehlo.dynamic_reshape %13847, %from_elements_4877 : (tensor, tensor<2xi64>) -> tensor + %13883 = stablehlo.concatenate %13880, %13882, dim = 1 : (tensor, tensor) -> tensor + %13884 = "stablehlo.gather"(%13689, %13883) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %13885 = shape.shape_of %13878 : tensor -> tensor<2xindex> + %13886 = shape.shape_of %13884 : tensor -> tensor<2xindex> + %13887 = shape.cstr_broadcastable %13885, %13886 : tensor<2xindex>, tensor<2xindex> + %13888 = shape.assuming %13887 -> (tensor) { + %19688 = shape.broadcast %13885, %13886 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %13878, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %13884, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %13889 = shape.shape_of %13888 : tensor -> tensor<2xindex> + %13890 = stablehlo.dynamic_broadcast_in_dim %13888, %13889, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %13891 = stablehlo.dynamic_broadcast_in_dim %213, %13889, dims = [] : (tensor, tensor<2xindex>) -> tensor + %13892 = stablehlo.multiply %13890, %13891 : tensor + %dim_4878 = tensor.dim %13852, %c0 : tensor + %13893 = arith.index_cast %dim_4878 : index to i64 + %dim_4879 = tensor.dim %13888, %c0 : tensor + %13894 = arith.index_cast %dim_4879 : index to i64 + %13895 = arith.maxsi %13893, %13894 : i64 + %13896 = arith.index_cast %13895 : i64 to index + %from_elements_4880 = tensor.from_elements %13896, %c4096 : tensor<2xindex> + %13897 = stablehlo.dynamic_broadcast_in_dim %13852, %from_elements_4880, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_4881 = tensor.dim %13897, %c0 : tensor + %13898 = arith.index_cast %dim_4881 : index to i64 + %from_elements_4882 = tensor.from_elements %13898, %c4096_i64 : tensor<2xi64> + %13899 = stablehlo.real_dynamic_slice %13892, %c_22, %from_elements_4882, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_4883 = tensor.from_elements %13898, %c4096_i64, %c1_i64 : tensor<3xi64> + %13900 = stablehlo.dynamic_reshape %13897, %from_elements_4883 : (tensor, tensor<3xi64>) -> tensor + %13901 = stablehlo.dynamic_iota %from_elements_4883, dim = 1 : (tensor<3xi64>) -> tensor + %13902 = stablehlo.concatenate %13900, %13901, dim = 2 : (tensor, tensor) -> tensor + %13903 = 
"stablehlo.scatter"(%13840, %13902, %13899) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %13904 = stablehlo.slice %13649 [4:5, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %13905 = stablehlo.reshape %13904 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %13906 = stablehlo.custom_call @byteir.non_zero(%13905) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_4884 = tensor.dim %13906, %c0 : tensor + %13907 = arith.index_cast %dim_4884 : index to i64 + %from_elements_4885 = tensor.from_elements %13907, %c1_i64 : tensor<2xi64> + %13908 = stablehlo.real_dynamic_slice %13906, %c_22, %from_elements_4885, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4886 = tensor.dim %13908, %c0 : tensor + %13909 = arith.index_cast %dim_4886 : index to i64 + %from_elements_4887 = tensor.from_elements %13909 : tensor<1xi64> + %13910 = stablehlo.dynamic_reshape %13908, %from_elements_4887 : (tensor, tensor<1xi64>) -> tensor + %from_elements_4888 = tensor.from_elements %13907, %c2_i64 : tensor<2xi64> + %13911 = stablehlo.real_dynamic_slice %13906, %c_24, %from_elements_4888, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4889 = tensor.dim %13911, %c0 : tensor + %13912 = arith.index_cast %dim_4889 : index to i64 + %from_elements_4890 = tensor.from_elements %13912 : tensor<1xi64> + %13913 = stablehlo.dynamic_reshape %13911, %from_elements_4890 : (tensor, tensor<1xi64>) -> tensor + %dim_4891 = tensor.dim %13913, %c0 : tensor + %13914 = arith.index_cast %dim_4891 : index to i64 + %from_elements_4892 = tensor.from_elements %13914, %c1_i64 : tensor<2xi64> + %13915 = stablehlo.dynamic_reshape %13913, %from_elements_4892 : (tensor, tensor<2xi64>) -> tensor + %dim_4893 = tensor.dim %13915, %c0 : tensor + %13916 = arith.index_cast %dim_4893 : index to i64 + %from_elements_4894 = tensor.from_elements %c1_i64, %13916, %c4096_i64 : tensor<3xi64> + %13917 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_4894, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4895 = tensor.dim %13917, %c1 : tensor<1x?x4096xi64> + %13918 = arith.index_cast %dim_4895 : index to i64 + %from_elements_4896 = tensor.from_elements %c1_i64, %13918, %c4096_i64, %c1_i64 : tensor<4xi64> + %13919 = stablehlo.dynamic_reshape %13917, %from_elements_4896 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %13920 = stablehlo.dynamic_broadcast_in_dim %13915, %from_elements_4894, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4897 = tensor.dim %13920, %c1 : tensor<1x?x4096xi64> + %13921 = arith.index_cast %dim_4897 : index to i64 + %from_elements_4898 = tensor.from_elements %c1_i64, %13921, %c4096_i64, %c1_i64 : tensor<4xi64> + %13922 = stablehlo.dynamic_reshape %13920, %from_elements_4898 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %13923 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_4894, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4899 = tensor.dim %13923, %c1 : tensor<1x?x4096xi64> + %13924 = arith.index_cast %dim_4899 : index to i64 + %from_elements_4900 = tensor.from_elements %c1_i64, %13924, %c4096_i64, %c1_i64 : tensor<4xi64> + %13925 = stablehlo.dynamic_reshape %13923, 
%from_elements_4900 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %13926 = stablehlo.concatenate %13919, %13922, %13925, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %13927 = "stablehlo.gather"(%13660, %13926) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %13928 = shape.shape_of %13927 : tensor<1x?x4096xf32> -> tensor<3xindex> + %13929 = shape.num_elements %13928 : tensor<3xindex> -> index + %13930 = stablehlo.compute_reshape_shape %13929, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %13931 = stablehlo.dynamic_reshape %13927, %13930 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %13932 = stablehlo.dot %13931, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %13933 = stablehlo.logistic %13932 : tensor + %13934 = shape.shape_of %13933 : tensor -> tensor<2xindex> + %13935 = shape.shape_of %13932 : tensor -> tensor<2xindex> + %13936 = shape.cstr_broadcastable %13934, %13935 : tensor<2xindex>, tensor<2xindex> + %13937 = shape.assuming %13936 -> (tensor) { + %19688 = shape.broadcast %13934, %13935 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %13933, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %13932, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %13938 = shape.shape_of %13937 : tensor -> tensor<2xindex> + %13939 = shape.cstr_broadcastable %13938, %13935 : tensor<2xindex>, tensor<2xindex> + %13940 = shape.assuming %13939 -> (tensor) { + %19688 = shape.broadcast %13938, %13935 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %13937, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %13932, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %13941 = stablehlo.dot %13940, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_4901 = tensor.dim %13913, %c0 : tensor + %13942 = arith.index_cast %dim_4901 : index to i64 + %from_elements_4902 = tensor.from_elements %13942, %c1_i64 : tensor<2xi64> + %13943 = stablehlo.dynamic_reshape %13913, %from_elements_4902 : (tensor, tensor<2xi64>) -> tensor + %dim_4903 = tensor.dim %13910, %c0 : tensor + %13944 = arith.index_cast %dim_4903 : index to i64 + %from_elements_4904 = tensor.from_elements %13944, %c1_i64 : tensor<2xi64> + %13945 = stablehlo.dynamic_reshape %13910, %from_elements_4904 : (tensor, tensor<2xi64>) -> tensor + %13946 = stablehlo.concatenate %13943, %13945, dim = 1 : (tensor, tensor) -> tensor + %13947 = "stablehlo.gather"(%13689, %13946) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %13948 = shape.shape_of %13941 : tensor -> tensor<2xindex> + %13949 = shape.shape_of %13947 : tensor -> tensor<2xindex> + %13950 = shape.cstr_broadcastable %13948, %13949 : tensor<2xindex>, tensor<2xindex> + %13951 = shape.assuming %13950 -> (tensor) { + %19688 = shape.broadcast %13948, %13949 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %13941, %19688, dims = [0, 1] : 
(tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %13947, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %13952 = shape.shape_of %13951 : tensor -> tensor<2xindex> + %13953 = stablehlo.dynamic_broadcast_in_dim %13951, %13952, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %13954 = stablehlo.dynamic_broadcast_in_dim %213, %13952, dims = [] : (tensor, tensor<2xindex>) -> tensor + %13955 = stablehlo.multiply %13953, %13954 : tensor + %dim_4905 = tensor.dim %13915, %c0 : tensor + %13956 = arith.index_cast %dim_4905 : index to i64 + %dim_4906 = tensor.dim %13951, %c0 : tensor + %13957 = arith.index_cast %dim_4906 : index to i64 + %13958 = arith.maxsi %13956, %13957 : i64 + %13959 = arith.index_cast %13958 : i64 to index + %from_elements_4907 = tensor.from_elements %13959, %c4096 : tensor<2xindex> + %13960 = stablehlo.dynamic_broadcast_in_dim %13915, %from_elements_4907, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_4908 = tensor.dim %13960, %c0 : tensor + %13961 = arith.index_cast %dim_4908 : index to i64 + %from_elements_4909 = tensor.from_elements %13961, %c4096_i64 : tensor<2xi64> + %13962 = stablehlo.real_dynamic_slice %13955, %c_22, %from_elements_4909, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_4910 = tensor.from_elements %13961, %c4096_i64, %c1_i64 : tensor<3xi64> + %13963 = stablehlo.dynamic_reshape %13960, %from_elements_4910 : (tensor, tensor<3xi64>) -> tensor + %13964 = stablehlo.dynamic_iota %from_elements_4910, dim = 1 : (tensor<3xi64>) -> tensor + %13965 = stablehlo.concatenate %13963, %13964, dim = 2 : (tensor, tensor) -> tensor + %13966 = "stablehlo.scatter"(%13903, %13965, %13962) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %13967 = stablehlo.slice %13649 [5:6, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %13968 = stablehlo.reshape %13967 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %13969 = stablehlo.custom_call @byteir.non_zero(%13968) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_4911 = tensor.dim %13969, %c0 : tensor + %13970 = arith.index_cast %dim_4911 : index to i64 + %from_elements_4912 = tensor.from_elements %13970, %c1_i64 : tensor<2xi64> + %13971 = stablehlo.real_dynamic_slice %13969, %c_22, %from_elements_4912, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4913 = tensor.dim %13971, %c0 : tensor + %13972 = arith.index_cast %dim_4913 : index to i64 + %from_elements_4914 = tensor.from_elements %13972 : tensor<1xi64> + %13973 = stablehlo.dynamic_reshape %13971, %from_elements_4914 : (tensor, tensor<1xi64>) -> tensor + %from_elements_4915 = tensor.from_elements %13970, %c2_i64 : tensor<2xi64> + %13974 = stablehlo.real_dynamic_slice %13969, %c_24, %from_elements_4915, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4916 = tensor.dim %13974, %c0 : tensor + %13975 = arith.index_cast %dim_4916 : index to i64 + %from_elements_4917 = tensor.from_elements %13975 : tensor<1xi64> + %13976 = stablehlo.dynamic_reshape %13974, %from_elements_4917 : (tensor, tensor<1xi64>) -> tensor + %dim_4918 = tensor.dim %13976, %c0 : tensor + %13977 = 
arith.index_cast %dim_4918 : index to i64 + %from_elements_4919 = tensor.from_elements %13977, %c1_i64 : tensor<2xi64> + %13978 = stablehlo.dynamic_reshape %13976, %from_elements_4919 : (tensor, tensor<2xi64>) -> tensor + %dim_4920 = tensor.dim %13978, %c0 : tensor + %13979 = arith.index_cast %dim_4920 : index to i64 + %from_elements_4921 = tensor.from_elements %c1_i64, %13979, %c4096_i64 : tensor<3xi64> + %13980 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_4921, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4922 = tensor.dim %13980, %c1 : tensor<1x?x4096xi64> + %13981 = arith.index_cast %dim_4922 : index to i64 + %from_elements_4923 = tensor.from_elements %c1_i64, %13981, %c4096_i64, %c1_i64 : tensor<4xi64> + %13982 = stablehlo.dynamic_reshape %13980, %from_elements_4923 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %13983 = stablehlo.dynamic_broadcast_in_dim %13978, %from_elements_4921, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4924 = tensor.dim %13983, %c1 : tensor<1x?x4096xi64> + %13984 = arith.index_cast %dim_4924 : index to i64 + %from_elements_4925 = tensor.from_elements %c1_i64, %13984, %c4096_i64, %c1_i64 : tensor<4xi64> + %13985 = stablehlo.dynamic_reshape %13983, %from_elements_4925 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %13986 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_4921, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4926 = tensor.dim %13986, %c1 : tensor<1x?x4096xi64> + %13987 = arith.index_cast %dim_4926 : index to i64 + %from_elements_4927 = tensor.from_elements %c1_i64, %13987, %c4096_i64, %c1_i64 : tensor<4xi64> + %13988 = stablehlo.dynamic_reshape %13986, %from_elements_4927 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %13989 = stablehlo.concatenate %13982, %13985, %13988, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %13990 = "stablehlo.gather"(%13660, %13989) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %13991 = shape.shape_of %13990 : tensor<1x?x4096xf32> -> tensor<3xindex> + %13992 = shape.num_elements %13991 : tensor<3xindex> -> index + %13993 = stablehlo.compute_reshape_shape %13992, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %13994 = stablehlo.dynamic_reshape %13990, %13993 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %13995 = stablehlo.dot %13994, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %13996 = stablehlo.logistic %13995 : tensor + %13997 = shape.shape_of %13996 : tensor -> tensor<2xindex> + %13998 = shape.shape_of %13995 : tensor -> tensor<2xindex> + %13999 = shape.cstr_broadcastable %13997, %13998 : tensor<2xindex>, tensor<2xindex> + %14000 = shape.assuming %13999 -> (tensor) { + %19688 = shape.broadcast %13997, %13998 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %13996, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %13995, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %14001 = shape.shape_of %14000 : tensor -> tensor<2xindex> + %14002 = shape.cstr_broadcastable %14001, %13998 : tensor<2xindex>, tensor<2xindex> + %14003 = 
shape.assuming %14002 -> (tensor) { + %19688 = shape.broadcast %14001, %13998 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %14000, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %13995, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %14004 = stablehlo.dot %14003, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_4928 = tensor.dim %13976, %c0 : tensor + %14005 = arith.index_cast %dim_4928 : index to i64 + %from_elements_4929 = tensor.from_elements %14005, %c1_i64 : tensor<2xi64> + %14006 = stablehlo.dynamic_reshape %13976, %from_elements_4929 : (tensor, tensor<2xi64>) -> tensor + %dim_4930 = tensor.dim %13973, %c0 : tensor + %14007 = arith.index_cast %dim_4930 : index to i64 + %from_elements_4931 = tensor.from_elements %14007, %c1_i64 : tensor<2xi64> + %14008 = stablehlo.dynamic_reshape %13973, %from_elements_4931 : (tensor, tensor<2xi64>) -> tensor + %14009 = stablehlo.concatenate %14006, %14008, dim = 1 : (tensor, tensor) -> tensor + %14010 = "stablehlo.gather"(%13689, %14009) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %14011 = shape.shape_of %14004 : tensor -> tensor<2xindex> + %14012 = shape.shape_of %14010 : tensor -> tensor<2xindex> + %14013 = shape.cstr_broadcastable %14011, %14012 : tensor<2xindex>, tensor<2xindex> + %14014 = shape.assuming %14013 -> (tensor) { + %19688 = shape.broadcast %14011, %14012 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %14004, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %14010, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %14015 = shape.shape_of %14014 : tensor -> tensor<2xindex> + %14016 = stablehlo.dynamic_broadcast_in_dim %14014, %14015, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %14017 = stablehlo.dynamic_broadcast_in_dim %213, %14015, dims = [] : (tensor, tensor<2xindex>) -> tensor + %14018 = stablehlo.multiply %14016, %14017 : tensor + %dim_4932 = tensor.dim %13978, %c0 : tensor + %14019 = arith.index_cast %dim_4932 : index to i64 + %dim_4933 = tensor.dim %14014, %c0 : tensor + %14020 = arith.index_cast %dim_4933 : index to i64 + %14021 = arith.maxsi %14019, %14020 : i64 + %14022 = arith.index_cast %14021 : i64 to index + %from_elements_4934 = tensor.from_elements %14022, %c4096 : tensor<2xindex> + %14023 = stablehlo.dynamic_broadcast_in_dim %13978, %from_elements_4934, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_4935 = tensor.dim %14023, %c0 : tensor + %14024 = arith.index_cast %dim_4935 : index to i64 + %from_elements_4936 = tensor.from_elements %14024, %c4096_i64 : tensor<2xi64> + %14025 = stablehlo.real_dynamic_slice %14018, %c_22, %from_elements_4936, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_4937 = tensor.from_elements %14024, %c4096_i64, %c1_i64 : tensor<3xi64> + %14026 = stablehlo.dynamic_reshape %14023, %from_elements_4937 : (tensor, tensor<3xi64>) -> tensor + %14027 = stablehlo.dynamic_iota %from_elements_4937, dim = 1 : (tensor<3xi64>) -> tensor + %14028 = stablehlo.concatenate %14026, %14027, dim = 2 : (tensor, tensor) -> 
tensor + %14029 = "stablehlo.scatter"(%13966, %14028, %14025) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %14030 = stablehlo.slice %13649 [6:7, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %14031 = stablehlo.reshape %14030 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %14032 = stablehlo.custom_call @byteir.non_zero(%14031) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_4938 = tensor.dim %14032, %c0 : tensor + %14033 = arith.index_cast %dim_4938 : index to i64 + %from_elements_4939 = tensor.from_elements %14033, %c1_i64 : tensor<2xi64> + %14034 = stablehlo.real_dynamic_slice %14032, %c_22, %from_elements_4939, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4940 = tensor.dim %14034, %c0 : tensor + %14035 = arith.index_cast %dim_4940 : index to i64 + %from_elements_4941 = tensor.from_elements %14035 : tensor<1xi64> + %14036 = stablehlo.dynamic_reshape %14034, %from_elements_4941 : (tensor, tensor<1xi64>) -> tensor + %from_elements_4942 = tensor.from_elements %14033, %c2_i64 : tensor<2xi64> + %14037 = stablehlo.real_dynamic_slice %14032, %c_24, %from_elements_4942, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4943 = tensor.dim %14037, %c0 : tensor + %14038 = arith.index_cast %dim_4943 : index to i64 + %from_elements_4944 = tensor.from_elements %14038 : tensor<1xi64> + %14039 = stablehlo.dynamic_reshape %14037, %from_elements_4944 : (tensor, tensor<1xi64>) -> tensor + %dim_4945 = tensor.dim %14039, %c0 : tensor + %14040 = arith.index_cast %dim_4945 : index to i64 + %from_elements_4946 = tensor.from_elements %14040, %c1_i64 : tensor<2xi64> + %14041 = stablehlo.dynamic_reshape %14039, %from_elements_4946 : (tensor, tensor<2xi64>) -> tensor + %dim_4947 = tensor.dim %14041, %c0 : tensor + %14042 = arith.index_cast %dim_4947 : index to i64 + %from_elements_4948 = tensor.from_elements %c1_i64, %14042, %c4096_i64 : tensor<3xi64> + %14043 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_4948, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4949 = tensor.dim %14043, %c1 : tensor<1x?x4096xi64> + %14044 = arith.index_cast %dim_4949 : index to i64 + %from_elements_4950 = tensor.from_elements %c1_i64, %14044, %c4096_i64, %c1_i64 : tensor<4xi64> + %14045 = stablehlo.dynamic_reshape %14043, %from_elements_4950 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %14046 = stablehlo.dynamic_broadcast_in_dim %14041, %from_elements_4948, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4951 = tensor.dim %14046, %c1 : tensor<1x?x4096xi64> + %14047 = arith.index_cast %dim_4951 : index to i64 + %from_elements_4952 = tensor.from_elements %c1_i64, %14047, %c4096_i64, %c1_i64 : tensor<4xi64> + %14048 = stablehlo.dynamic_reshape %14046, %from_elements_4952 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %14049 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_4948, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4953 = tensor.dim %14049, %c1 : tensor<1x?x4096xi64> + %14050 = arith.index_cast %dim_4953 : index to i64 + %from_elements_4954 = tensor.from_elements %c1_i64, %14050, %c4096_i64, %c1_i64 : tensor<4xi64> + %14051 = 
stablehlo.dynamic_reshape %14049, %from_elements_4954 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %14052 = stablehlo.concatenate %14045, %14048, %14051, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %14053 = "stablehlo.gather"(%13660, %14052) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %14054 = shape.shape_of %14053 : tensor<1x?x4096xf32> -> tensor<3xindex> + %14055 = shape.num_elements %14054 : tensor<3xindex> -> index + %14056 = stablehlo.compute_reshape_shape %14055, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %14057 = stablehlo.dynamic_reshape %14053, %14056 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %14058 = stablehlo.dot %14057, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %14059 = stablehlo.logistic %14058 : tensor + %14060 = shape.shape_of %14059 : tensor -> tensor<2xindex> + %14061 = shape.shape_of %14058 : tensor -> tensor<2xindex> + %14062 = shape.cstr_broadcastable %14060, %14061 : tensor<2xindex>, tensor<2xindex> + %14063 = shape.assuming %14062 -> (tensor) { + %19688 = shape.broadcast %14060, %14061 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %14059, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %14058, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %14064 = shape.shape_of %14063 : tensor -> tensor<2xindex> + %14065 = shape.cstr_broadcastable %14064, %14061 : tensor<2xindex>, tensor<2xindex> + %14066 = shape.assuming %14065 -> (tensor) { + %19688 = shape.broadcast %14064, %14061 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %14063, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %14058, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %14067 = stablehlo.dot %14066, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_4955 = tensor.dim %14039, %c0 : tensor + %14068 = arith.index_cast %dim_4955 : index to i64 + %from_elements_4956 = tensor.from_elements %14068, %c1_i64 : tensor<2xi64> + %14069 = stablehlo.dynamic_reshape %14039, %from_elements_4956 : (tensor, tensor<2xi64>) -> tensor + %dim_4957 = tensor.dim %14036, %c0 : tensor + %14070 = arith.index_cast %dim_4957 : index to i64 + %from_elements_4958 = tensor.from_elements %14070, %c1_i64 : tensor<2xi64> + %14071 = stablehlo.dynamic_reshape %14036, %from_elements_4958 : (tensor, tensor<2xi64>) -> tensor + %14072 = stablehlo.concatenate %14069, %14071, dim = 1 : (tensor, tensor) -> tensor + %14073 = "stablehlo.gather"(%13689, %14072) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %14074 = shape.shape_of %14067 : tensor -> tensor<2xindex> + %14075 = shape.shape_of %14073 : tensor -> tensor<2xindex> + %14076 = shape.cstr_broadcastable %14074, %14075 : tensor<2xindex>, tensor<2xindex> + %14077 = shape.assuming %14076 -> (tensor) { + %19688 = shape.broadcast %14074, %14075 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim 
%14067, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %14073, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %14078 = shape.shape_of %14077 : tensor -> tensor<2xindex> + %14079 = stablehlo.dynamic_broadcast_in_dim %14077, %14078, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %14080 = stablehlo.dynamic_broadcast_in_dim %213, %14078, dims = [] : (tensor, tensor<2xindex>) -> tensor + %14081 = stablehlo.multiply %14079, %14080 : tensor + %dim_4959 = tensor.dim %14041, %c0 : tensor + %14082 = arith.index_cast %dim_4959 : index to i64 + %dim_4960 = tensor.dim %14077, %c0 : tensor + %14083 = arith.index_cast %dim_4960 : index to i64 + %14084 = arith.maxsi %14082, %14083 : i64 + %14085 = arith.index_cast %14084 : i64 to index + %from_elements_4961 = tensor.from_elements %14085, %c4096 : tensor<2xindex> + %14086 = stablehlo.dynamic_broadcast_in_dim %14041, %from_elements_4961, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_4962 = tensor.dim %14086, %c0 : tensor + %14087 = arith.index_cast %dim_4962 : index to i64 + %from_elements_4963 = tensor.from_elements %14087, %c4096_i64 : tensor<2xi64> + %14088 = stablehlo.real_dynamic_slice %14081, %c_22, %from_elements_4963, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_4964 = tensor.from_elements %14087, %c4096_i64, %c1_i64 : tensor<3xi64> + %14089 = stablehlo.dynamic_reshape %14086, %from_elements_4964 : (tensor, tensor<3xi64>) -> tensor + %14090 = stablehlo.dynamic_iota %from_elements_4964, dim = 1 : (tensor<3xi64>) -> tensor + %14091 = stablehlo.concatenate %14089, %14090, dim = 2 : (tensor, tensor) -> tensor + %14092 = "stablehlo.scatter"(%14029, %14091, %14088) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %14093 = stablehlo.slice %13649 [7:8, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %14094 = stablehlo.reshape %14093 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %14095 = stablehlo.custom_call @byteir.non_zero(%14094) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_4965 = tensor.dim %14095, %c0 : tensor + %14096 = arith.index_cast %dim_4965 : index to i64 + %from_elements_4966 = tensor.from_elements %14096, %c1_i64 : tensor<2xi64> + %14097 = stablehlo.real_dynamic_slice %14095, %c_22, %from_elements_4966, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4967 = tensor.dim %14097, %c0 : tensor + %14098 = arith.index_cast %dim_4967 : index to i64 + %from_elements_4968 = tensor.from_elements %14098 : tensor<1xi64> + %14099 = stablehlo.dynamic_reshape %14097, %from_elements_4968 : (tensor, tensor<1xi64>) -> tensor + %from_elements_4969 = tensor.from_elements %14096, %c2_i64 : tensor<2xi64> + %14100 = stablehlo.real_dynamic_slice %14095, %c_24, %from_elements_4969, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_4970 = tensor.dim %14100, %c0 : tensor + %14101 = arith.index_cast %dim_4970 : index to i64 + %from_elements_4971 = tensor.from_elements %14101 : tensor<1xi64> + %14102 = stablehlo.dynamic_reshape %14100, %from_elements_4971 : (tensor, tensor<1xi64>) -> tensor + %dim_4972 = tensor.dim 
%14102, %c0 : tensor + %14103 = arith.index_cast %dim_4972 : index to i64 + %from_elements_4973 = tensor.from_elements %14103, %c1_i64 : tensor<2xi64> + %14104 = stablehlo.dynamic_reshape %14102, %from_elements_4973 : (tensor, tensor<2xi64>) -> tensor + %dim_4974 = tensor.dim %14104, %c0 : tensor + %14105 = arith.index_cast %dim_4974 : index to i64 + %from_elements_4975 = tensor.from_elements %c1_i64, %14105, %c4096_i64 : tensor<3xi64> + %14106 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_4975, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4976 = tensor.dim %14106, %c1 : tensor<1x?x4096xi64> + %14107 = arith.index_cast %dim_4976 : index to i64 + %from_elements_4977 = tensor.from_elements %c1_i64, %14107, %c4096_i64, %c1_i64 : tensor<4xi64> + %14108 = stablehlo.dynamic_reshape %14106, %from_elements_4977 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %14109 = stablehlo.dynamic_broadcast_in_dim %14104, %from_elements_4975, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4978 = tensor.dim %14109, %c1 : tensor<1x?x4096xi64> + %14110 = arith.index_cast %dim_4978 : index to i64 + %from_elements_4979 = tensor.from_elements %c1_i64, %14110, %c4096_i64, %c1_i64 : tensor<4xi64> + %14111 = stablehlo.dynamic_reshape %14109, %from_elements_4979 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %14112 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_4975, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_4980 = tensor.dim %14112, %c1 : tensor<1x?x4096xi64> + %14113 = arith.index_cast %dim_4980 : index to i64 + %from_elements_4981 = tensor.from_elements %c1_i64, %14113, %c4096_i64, %c1_i64 : tensor<4xi64> + %14114 = stablehlo.dynamic_reshape %14112, %from_elements_4981 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %14115 = stablehlo.concatenate %14108, %14111, %14114, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %14116 = "stablehlo.gather"(%13660, %14115) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %14117 = shape.shape_of %14116 : tensor<1x?x4096xf32> -> tensor<3xindex> + %14118 = shape.num_elements %14117 : tensor<3xindex> -> index + %14119 = stablehlo.compute_reshape_shape %14118, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %14120 = stablehlo.dynamic_reshape %14116, %14119 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %14121 = stablehlo.dot %14120, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %14122 = stablehlo.logistic %14121 : tensor + %14123 = shape.shape_of %14122 : tensor -> tensor<2xindex> + %14124 = shape.shape_of %14121 : tensor -> tensor<2xindex> + %14125 = shape.cstr_broadcastable %14123, %14124 : tensor<2xindex>, tensor<2xindex> + %14126 = shape.assuming %14125 -> (tensor) { + %19688 = shape.broadcast %14123, %14124 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %14122, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %14121, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %14127 = shape.shape_of %14126 : tensor -> tensor<2xindex> + %14128 = shape.cstr_broadcastable %14127, %14124 : tensor<2xindex>, 
tensor<2xindex> + %14129 = shape.assuming %14128 -> (tensor) { + %19688 = shape.broadcast %14127, %14124 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %14126, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %14121, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %14130 = stablehlo.dot %14129, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_4982 = tensor.dim %14102, %c0 : tensor + %14131 = arith.index_cast %dim_4982 : index to i64 + %from_elements_4983 = tensor.from_elements %14131, %c1_i64 : tensor<2xi64> + %14132 = stablehlo.dynamic_reshape %14102, %from_elements_4983 : (tensor, tensor<2xi64>) -> tensor + %dim_4984 = tensor.dim %14099, %c0 : tensor + %14133 = arith.index_cast %dim_4984 : index to i64 + %from_elements_4985 = tensor.from_elements %14133, %c1_i64 : tensor<2xi64> + %14134 = stablehlo.dynamic_reshape %14099, %from_elements_4985 : (tensor, tensor<2xi64>) -> tensor + %14135 = stablehlo.concatenate %14132, %14134, dim = 1 : (tensor, tensor) -> tensor + %14136 = "stablehlo.gather"(%13689, %14135) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %14137 = shape.shape_of %14130 : tensor -> tensor<2xindex> + %14138 = shape.shape_of %14136 : tensor -> tensor<2xindex> + %14139 = shape.cstr_broadcastable %14137, %14138 : tensor<2xindex>, tensor<2xindex> + %14140 = shape.assuming %14139 -> (tensor) { + %19688 = shape.broadcast %14137, %14138 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %14130, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %14136, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %14141 = shape.shape_of %14140 : tensor -> tensor<2xindex> + %14142 = stablehlo.dynamic_broadcast_in_dim %14140, %14141, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %14143 = stablehlo.dynamic_broadcast_in_dim %213, %14141, dims = [] : (tensor, tensor<2xindex>) -> tensor + %14144 = stablehlo.multiply %14142, %14143 : tensor + %dim_4986 = tensor.dim %14104, %c0 : tensor + %14145 = arith.index_cast %dim_4986 : index to i64 + %dim_4987 = tensor.dim %14140, %c0 : tensor + %14146 = arith.index_cast %dim_4987 : index to i64 + %14147 = arith.maxsi %14145, %14146 : i64 + %14148 = arith.index_cast %14147 : i64 to index + %from_elements_4988 = tensor.from_elements %14148, %c4096 : tensor<2xindex> + %14149 = stablehlo.dynamic_broadcast_in_dim %14104, %from_elements_4988, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_4989 = tensor.dim %14149, %c0 : tensor + %14150 = arith.index_cast %dim_4989 : index to i64 + %from_elements_4990 = tensor.from_elements %14150, %c4096_i64 : tensor<2xi64> + %14151 = stablehlo.real_dynamic_slice %14144, %c_22, %from_elements_4990, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_4991 = tensor.from_elements %14150, %c4096_i64, %c1_i64 : tensor<3xi64> + %14152 = stablehlo.dynamic_reshape %14149, %from_elements_4991 : (tensor, tensor<3xi64>) -> tensor + %14153 = stablehlo.dynamic_iota %from_elements_4991, dim = 1 : (tensor<3xi64>) -> tensor + %14154 = stablehlo.concatenate %14152, %14153, dim 
= 2 : (tensor, tensor) -> tensor
+    // ... generated StableHLO continuing the unrolled Mixtral decoder stack: the scatter-add that
+    // combines the previous layer's expert outputs, followed by its residual add; then the next
+    // layer's RMSNorm over the 3x1x4096 hidden states; rotary grouped-query attention with
+    // 32 query heads and 8 key/value heads of dim 128, concatenating the new key/value with the
+    // cached 3x8x7x128 entries and applying byteir.softmax to the 3x32x1x8 scores; the
+    // byteir.softmax + byteir.top_k router that keeps 2 of 8 experts and renormalizes their
+    // weights; and the per-expert blocks (experts 0-6 here), each using byteir.non_zero to locate
+    // its routed tokens, gathering their hidden states, applying the SiLU-gated w1/w3/w2 MLP
+    // (4096 -> 14336 -> 4096), scaling by the routing weight, and scatter-adding into the
+    // 3x4096 accumulator ...
+ %14690 = shape.assuming %14689 -> 
(tensor) { + %19688 = shape.broadcast %14687, %14688 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %14680, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %14686, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %14691 = shape.shape_of %14690 : tensor -> tensor<2xindex> + %14692 = stablehlo.dynamic_broadcast_in_dim %14690, %14691, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %14693 = stablehlo.dynamic_broadcast_in_dim %213, %14691, dims = [] : (tensor, tensor<2xindex>) -> tensor + %14694 = stablehlo.multiply %14692, %14693 : tensor + %dim_5175 = tensor.dim %14654, %c0 : tensor + %14695 = arith.index_cast %dim_5175 : index to i64 + %dim_5176 = tensor.dim %14690, %c0 : tensor + %14696 = arith.index_cast %dim_5176 : index to i64 + %14697 = arith.maxsi %14695, %14696 : i64 + %14698 = arith.index_cast %14697 : i64 to index + %from_elements_5177 = tensor.from_elements %14698, %c4096 : tensor<2xindex> + %14699 = stablehlo.dynamic_broadcast_in_dim %14654, %from_elements_5177, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_5178 = tensor.dim %14699, %c0 : tensor + %14700 = arith.index_cast %dim_5178 : index to i64 + %from_elements_5179 = tensor.from_elements %14700, %c4096_i64 : tensor<2xi64> + %14701 = stablehlo.real_dynamic_slice %14694, %c_22, %from_elements_5179, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_5180 = tensor.from_elements %14700, %c4096_i64, %c1_i64 : tensor<3xi64> + %14702 = stablehlo.dynamic_reshape %14699, %from_elements_5180 : (tensor, tensor<3xi64>) -> tensor + %14703 = stablehlo.dynamic_iota %from_elements_5180, dim = 1 : (tensor<3xi64>) -> tensor + %14704 = stablehlo.concatenate %14702, %14703, dim = 2 : (tensor, tensor) -> tensor + %14705 = "stablehlo.scatter"(%14642, %14704, %14701) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %14706 = stablehlo.slice %14262 [7:8, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %14707 = stablehlo.reshape %14706 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %14708 = stablehlo.custom_call @byteir.non_zero(%14707) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_5181 = tensor.dim %14708, %c0 : tensor + %14709 = arith.index_cast %dim_5181 : index to i64 + %from_elements_5182 = tensor.from_elements %14709, %c1_i64 : tensor<2xi64> + %14710 = stablehlo.real_dynamic_slice %14708, %c_22, %from_elements_5182, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5183 = tensor.dim %14710, %c0 : tensor + %14711 = arith.index_cast %dim_5183 : index to i64 + %from_elements_5184 = tensor.from_elements %14711 : tensor<1xi64> + %14712 = stablehlo.dynamic_reshape %14710, %from_elements_5184 : (tensor, tensor<1xi64>) -> tensor + %from_elements_5185 = tensor.from_elements %14709, %c2_i64 : tensor<2xi64> + %14713 = stablehlo.real_dynamic_slice %14708, %c_24, %from_elements_5185, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5186 = tensor.dim %14713, %c0 : tensor + %14714 = arith.index_cast %dim_5186 : index to i64 + %from_elements_5187 = tensor.from_elements 
%14714 : tensor<1xi64> + %14715 = stablehlo.dynamic_reshape %14713, %from_elements_5187 : (tensor, tensor<1xi64>) -> tensor + %dim_5188 = tensor.dim %14715, %c0 : tensor + %14716 = arith.index_cast %dim_5188 : index to i64 + %from_elements_5189 = tensor.from_elements %14716, %c1_i64 : tensor<2xi64> + %14717 = stablehlo.dynamic_reshape %14715, %from_elements_5189 : (tensor, tensor<2xi64>) -> tensor + %dim_5190 = tensor.dim %14717, %c0 : tensor + %14718 = arith.index_cast %dim_5190 : index to i64 + %from_elements_5191 = tensor.from_elements %c1_i64, %14718, %c4096_i64 : tensor<3xi64> + %14719 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_5191, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5192 = tensor.dim %14719, %c1 : tensor<1x?x4096xi64> + %14720 = arith.index_cast %dim_5192 : index to i64 + %from_elements_5193 = tensor.from_elements %c1_i64, %14720, %c4096_i64, %c1_i64 : tensor<4xi64> + %14721 = stablehlo.dynamic_reshape %14719, %from_elements_5193 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %14722 = stablehlo.dynamic_broadcast_in_dim %14717, %from_elements_5191, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5194 = tensor.dim %14722, %c1 : tensor<1x?x4096xi64> + %14723 = arith.index_cast %dim_5194 : index to i64 + %from_elements_5195 = tensor.from_elements %c1_i64, %14723, %c4096_i64, %c1_i64 : tensor<4xi64> + %14724 = stablehlo.dynamic_reshape %14722, %from_elements_5195 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %14725 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_5191, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5196 = tensor.dim %14725, %c1 : tensor<1x?x4096xi64> + %14726 = arith.index_cast %dim_5196 : index to i64 + %from_elements_5197 = tensor.from_elements %c1_i64, %14726, %c4096_i64, %c1_i64 : tensor<4xi64> + %14727 = stablehlo.dynamic_reshape %14725, %from_elements_5197 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %14728 = stablehlo.concatenate %14721, %14724, %14727, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %14729 = "stablehlo.gather"(%14273, %14728) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %14730 = shape.shape_of %14729 : tensor<1x?x4096xf32> -> tensor<3xindex> + %14731 = shape.num_elements %14730 : tensor<3xindex> -> index + %14732 = stablehlo.compute_reshape_shape %14731, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %14733 = stablehlo.dynamic_reshape %14729, %14732 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %14734 = stablehlo.dot %14733, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %14735 = stablehlo.logistic %14734 : tensor + %14736 = shape.shape_of %14735 : tensor -> tensor<2xindex> + %14737 = shape.shape_of %14734 : tensor -> tensor<2xindex> + %14738 = shape.cstr_broadcastable %14736, %14737 : tensor<2xindex>, tensor<2xindex> + %14739 = shape.assuming %14738 -> (tensor) { + %19688 = shape.broadcast %14736, %14737 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %14735, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %14734, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield 
%19691 : tensor + } + %14740 = shape.shape_of %14739 : tensor -> tensor<2xindex> + %14741 = shape.cstr_broadcastable %14740, %14737 : tensor<2xindex>, tensor<2xindex> + %14742 = shape.assuming %14741 -> (tensor) { + %19688 = shape.broadcast %14740, %14737 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %14739, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %14734, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %14743 = stablehlo.dot %14742, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_5198 = tensor.dim %14715, %c0 : tensor + %14744 = arith.index_cast %dim_5198 : index to i64 + %from_elements_5199 = tensor.from_elements %14744, %c1_i64 : tensor<2xi64> + %14745 = stablehlo.dynamic_reshape %14715, %from_elements_5199 : (tensor, tensor<2xi64>) -> tensor + %dim_5200 = tensor.dim %14712, %c0 : tensor + %14746 = arith.index_cast %dim_5200 : index to i64 + %from_elements_5201 = tensor.from_elements %14746, %c1_i64 : tensor<2xi64> + %14747 = stablehlo.dynamic_reshape %14712, %from_elements_5201 : (tensor, tensor<2xi64>) -> tensor + %14748 = stablehlo.concatenate %14745, %14747, dim = 1 : (tensor, tensor) -> tensor + %14749 = "stablehlo.gather"(%14302, %14748) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %14750 = shape.shape_of %14743 : tensor -> tensor<2xindex> + %14751 = shape.shape_of %14749 : tensor -> tensor<2xindex> + %14752 = shape.cstr_broadcastable %14750, %14751 : tensor<2xindex>, tensor<2xindex> + %14753 = shape.assuming %14752 -> (tensor) { + %19688 = shape.broadcast %14750, %14751 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %14743, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %14749, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %14754 = shape.shape_of %14753 : tensor -> tensor<2xindex> + %14755 = stablehlo.dynamic_broadcast_in_dim %14753, %14754, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %14756 = stablehlo.dynamic_broadcast_in_dim %213, %14754, dims = [] : (tensor, tensor<2xindex>) -> tensor + %14757 = stablehlo.multiply %14755, %14756 : tensor + %dim_5202 = tensor.dim %14717, %c0 : tensor + %14758 = arith.index_cast %dim_5202 : index to i64 + %dim_5203 = tensor.dim %14753, %c0 : tensor + %14759 = arith.index_cast %dim_5203 : index to i64 + %14760 = arith.maxsi %14758, %14759 : i64 + %14761 = arith.index_cast %14760 : i64 to index + %from_elements_5204 = tensor.from_elements %14761, %c4096 : tensor<2xindex> + %14762 = stablehlo.dynamic_broadcast_in_dim %14717, %from_elements_5204, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_5205 = tensor.dim %14762, %c0 : tensor + %14763 = arith.index_cast %dim_5205 : index to i64 + %from_elements_5206 = tensor.from_elements %14763, %c4096_i64 : tensor<2xi64> + %14764 = stablehlo.real_dynamic_slice %14757, %c_22, %from_elements_5206, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_5207 = tensor.from_elements %14763, %c4096_i64, %c1_i64 : tensor<3xi64> + %14765 = stablehlo.dynamic_reshape %14762, %from_elements_5207 : (tensor, tensor<3xi64>) -> 
tensor + %14766 = stablehlo.dynamic_iota %from_elements_5207, dim = 1 : (tensor<3xi64>) -> tensor + %14767 = stablehlo.concatenate %14765, %14766, dim = 2 : (tensor, tensor) -> tensor + %14768 = "stablehlo.scatter"(%14705, %14767, %14764) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %14769 = stablehlo.reshape %14768 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %14770 = stablehlo.add %14235, %14769 : tensor<3x1x4096xf32> + %14771 = stablehlo.broadcast_in_dim %14770, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %14772 = stablehlo.power %14771, %15 : tensor<3x1x4096xf32> + %14773 = stablehlo.reduce(%14772 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %14774 = stablehlo.reshape %14773 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %14775 = stablehlo.broadcast_in_dim %14774, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %14776 = stablehlo.divide %14775, %21 : tensor<3x1x1xf32> + %14777 = stablehlo.broadcast_in_dim %14776, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %14778 = stablehlo.add %14777, %25 : tensor<3x1x1xf32> + %14779 = stablehlo.rsqrt %14778 : tensor<3x1x1xf32> + %14780 = stablehlo.broadcast_in_dim %14779, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %14781 = stablehlo.multiply %14771, %14780 : tensor<3x1x4096xf32> + %14782 = stablehlo.broadcast_in_dim %14781, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %14783 = stablehlo.multiply %14782, %31 : tensor<3x1x4096xf32> + %14784 = stablehlo.reshape %14783 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %14785 = stablehlo.dot %14784, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %14786 = stablehlo.reshape %14785 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %14787 = stablehlo.dot %14784, %37 : (tensor<3x4096xf32>, tensor<4096x1024xf32>) -> tensor<3x1024xf32> + %14788 = stablehlo.reshape %14787 : (tensor<3x1024xf32>) -> tensor<3x1x1024xf32> + %14789 = stablehlo.reshape %14786 : (tensor<3x1x4096xf32>) -> tensor<3x1x32x128xf32> + %14790 = stablehlo.transpose %14789, dims = [0, 2, 1, 3] : (tensor<3x1x32x128xf32>) -> tensor<3x32x1x128xf32> + %14791 = stablehlo.reshape %14788 : (tensor<3x1x1024xf32>) -> tensor<3x1x8x128xf32> + %14792 = stablehlo.transpose %14791, dims = [0, 2, 1, 3] : (tensor<3x1x8x128xf32>) -> tensor<3x8x1x128xf32> + %14793 = stablehlo.slice %arg48 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %14794 = stablehlo.slice %arg49 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %14795 = "stablehlo.gather"(%14793, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %14796 = stablehlo.reshape %14795 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %14797 = "stablehlo.gather"(%14794, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %14798 = stablehlo.reshape %14797 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %14799 = stablehlo.broadcast_in_dim %14790, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + 
%14800 = stablehlo.broadcast_in_dim %14796, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %14801 = stablehlo.multiply %14799, %14800 : tensor<3x32x1x128xf32> + %14802 = stablehlo.slice %14790 [0:3, 0:32, 0:1, 0:64] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %14803 = stablehlo.slice %14790 [0:3, 0:32, 0:1, 64:128] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %14804 = stablehlo.negate %14803 : tensor<3x32x1x64xf32> + %14805 = stablehlo.concatenate %14804, %14802, dim = 3 : (tensor<3x32x1x64xf32>, tensor<3x32x1x64xf32>) -> tensor<3x32x1x128xf32> + %14806 = stablehlo.broadcast_in_dim %14805, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %14807 = stablehlo.broadcast_in_dim %14798, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %14808 = stablehlo.multiply %14806, %14807 : tensor<3x32x1x128xf32> + %14809 = stablehlo.add %14801, %14808 : tensor<3x32x1x128xf32> + %14810 = stablehlo.broadcast_in_dim %14792, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %14811 = stablehlo.broadcast_in_dim %14796, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %14812 = stablehlo.multiply %14810, %14811 : tensor<3x8x1x128xf32> + %14813 = stablehlo.slice %14792 [0:3, 0:8, 0:1, 0:64] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %14814 = stablehlo.slice %14792 [0:3, 0:8, 0:1, 64:128] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %14815 = stablehlo.negate %14814 : tensor<3x8x1x64xf32> + %14816 = stablehlo.concatenate %14815, %14813, dim = 3 : (tensor<3x8x1x64xf32>, tensor<3x8x1x64xf32>) -> tensor<3x8x1x128xf32> + %14817 = stablehlo.broadcast_in_dim %14816, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %14818 = stablehlo.broadcast_in_dim %14798, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %14819 = stablehlo.multiply %14817, %14818 : tensor<3x8x1x128xf32> + %14820 = stablehlo.add %14812, %14819 : tensor<3x8x1x128xf32> + %14821 = stablehlo.concatenate %arg113, %14820, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %14822 = stablehlo.concatenate %arg114, %14792, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %14823 = stablehlo.reshape %14821 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %14824 = stablehlo.broadcast_in_dim %14823, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %14825 = stablehlo.reshape %14824 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %14826 = stablehlo.reshape %14822 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %14827 = stablehlo.broadcast_in_dim %14826, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %14828 = stablehlo.reshape %14827 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %14829 = stablehlo.transpose %14825, dims = [0, 1, 3, 2] : (tensor<3x32x8x128xf32>) -> tensor<3x32x128x8xf32> + %14830 = stablehlo.reshape %14809 : (tensor<3x32x1x128xf32>) -> tensor<96x1x128xf32> + %14831 = stablehlo.reshape %14829 : (tensor<3x32x128x8xf32>) -> tensor<96x128x8xf32> + %14832 = stablehlo.broadcast_in_dim %14831, dims = [0, 1, 2] : (tensor<96x128x8xf32>) -> tensor<96x128x8xf32> + %14833 = stablehlo.dot_general %14830, %14832, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x128xf32>, tensor<96x128x8xf32>) -> tensor<96x1x8xf32> + %14834 = stablehlo.reshape %14833 : (tensor<96x1x8xf32>) -> 
tensor<3x32x1x8xf32> + %14835 = stablehlo.broadcast_in_dim %14834, dims = [0, 1, 2, 3] : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %14836 = stablehlo.divide %14835, %89 : tensor<3x32x1x8xf32> + %14837 = stablehlo.custom_call @byteir.softmax(%14836) {byteir_attrs = {axis = 3 : i64}} : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %14838 = stablehlo.reshape %14837 : (tensor<3x32x1x8xf32>) -> tensor<96x1x8xf32> + %14839 = stablehlo.reshape %14828 : (tensor<3x32x8x128xf32>) -> tensor<96x8x128xf32> + %14840 = stablehlo.broadcast_in_dim %14839, dims = [0, 1, 2] : (tensor<96x8x128xf32>) -> tensor<96x8x128xf32> + %14841 = stablehlo.dot_general %14838, %14840, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x8xf32>, tensor<96x8x128xf32>) -> tensor<96x1x128xf32> + %14842 = stablehlo.reshape %14841 : (tensor<96x1x128xf32>) -> tensor<3x32x1x128xf32> + %14843 = stablehlo.transpose %14842, dims = [0, 2, 1, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x1x32x128xf32> + %14844 = stablehlo.reshape %14843 : (tensor<3x1x32x128xf32>) -> tensor<3x1x4096xf32> + %14845 = stablehlo.reshape %14844 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %14846 = stablehlo.dot %14845, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %14847 = stablehlo.reshape %14846 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %14848 = stablehlo.add %14770, %14847 : tensor<3x1x4096xf32> + %14849 = stablehlo.broadcast_in_dim %14848, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %14850 = stablehlo.power %14849, %15 : tensor<3x1x4096xf32> + %14851 = stablehlo.reduce(%14850 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %14852 = stablehlo.reshape %14851 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %14853 = stablehlo.broadcast_in_dim %14852, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %14854 = stablehlo.divide %14853, %21 : tensor<3x1x1xf32> + %14855 = stablehlo.broadcast_in_dim %14854, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %14856 = stablehlo.add %14855, %25 : tensor<3x1x1xf32> + %14857 = stablehlo.rsqrt %14856 : tensor<3x1x1xf32> + %14858 = stablehlo.broadcast_in_dim %14857, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %14859 = stablehlo.multiply %14849, %14858 : tensor<3x1x4096xf32> + %14860 = stablehlo.broadcast_in_dim %14859, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %14861 = stablehlo.multiply %14860, %31 : tensor<3x1x4096xf32> + %14862 = stablehlo.reshape %14861 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %14863 = stablehlo.dot %14862, %117 : (tensor<3x4096xf32>, tensor<4096x8xf32>) -> tensor<3x8xf32> + %14864 = stablehlo.custom_call @byteir.softmax(%14863) {byteir_attrs = {axis = 1 : i64}} : (tensor<3x8xf32>) -> tensor<3x8xf32> + %14865:2 = stablehlo.custom_call @byteir.top_k(%14864) {byteir_attrs = {axis = [1], k = 2 : i64, sorted = true}} : (tensor<3x8xf32>) -> (tensor<3x2xf32>, tensor<3x2xi64>) + %14866 = stablehlo.reduce(%14865#0 init: %cst_20) applies stablehlo.add across dimensions = [1] : (tensor<3x2xf32>, tensor) -> tensor<3xf32> + %14867 = stablehlo.reshape %14866 : (tensor<3xf32>) -> tensor<3x1xf32> + %14868 = stablehlo.broadcast_in_dim %14865#0, dims = [0, 1] : (tensor<3x2xf32>) -> tensor<3x2xf32> + %14869 = stablehlo.broadcast_in_dim %14867, dims = [0, 1] : (tensor<3x1xf32>) -> tensor<3x2xf32> + %14870 = stablehlo.divide %14868, %14869 : tensor<3x2xf32> + %14871 = stablehlo.reshape %14865#1 : 
(tensor<3x2xi64>) -> tensor<3x2x1xi64> + %14872 = stablehlo.broadcast_in_dim %14871, dims = [0, 1, 2] : (tensor<3x2x1xi64>) -> tensor<3x2x8xi64> + %14873 = stablehlo.compare EQ, %14872, %137, SIGNED : (tensor<3x2x8xi64>, tensor<3x2x8xi64>) -> tensor<3x2x8xi1> + %14874 = stablehlo.convert %14873 : (tensor<3x2x8xi1>) -> tensor<3x2x8xi64> + %14875 = stablehlo.transpose %14874, dims = [2, 1, 0] : (tensor<3x2x8xi64>) -> tensor<8x2x3xi64> + %14876 = stablehlo.slice %14875 [0:1, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %14877 = stablehlo.reshape %14876 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %14878 = stablehlo.custom_call @byteir.non_zero(%14877) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_5208 = tensor.dim %14878, %c0 : tensor + %14879 = arith.index_cast %dim_5208 : index to i64 + %from_elements_5209 = tensor.from_elements %14879, %c1_i64 : tensor<2xi64> + %14880 = stablehlo.real_dynamic_slice %14878, %c_22, %from_elements_5209, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5210 = tensor.dim %14880, %c0 : tensor + %14881 = arith.index_cast %dim_5210 : index to i64 + %from_elements_5211 = tensor.from_elements %14881 : tensor<1xi64> + %14882 = stablehlo.dynamic_reshape %14880, %from_elements_5211 : (tensor, tensor<1xi64>) -> tensor + %from_elements_5212 = tensor.from_elements %14879, %c2_i64 : tensor<2xi64> + %14883 = stablehlo.real_dynamic_slice %14878, %c_24, %from_elements_5212, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5213 = tensor.dim %14883, %c0 : tensor + %14884 = arith.index_cast %dim_5213 : index to i64 + %from_elements_5214 = tensor.from_elements %14884 : tensor<1xi64> + %14885 = stablehlo.dynamic_reshape %14883, %from_elements_5214 : (tensor, tensor<1xi64>) -> tensor + %14886 = stablehlo.reshape %14862 : (tensor<3x4096xf32>) -> tensor<1x3x4096xf32> + %dim_5215 = tensor.dim %14885, %c0 : tensor + %14887 = arith.index_cast %dim_5215 : index to i64 + %from_elements_5216 = tensor.from_elements %14887, %c1_i64 : tensor<2xi64> + %14888 = stablehlo.dynamic_reshape %14885, %from_elements_5216 : (tensor, tensor<2xi64>) -> tensor + %dim_5217 = tensor.dim %14888, %c0 : tensor + %14889 = arith.index_cast %dim_5217 : index to i64 + %from_elements_5218 = tensor.from_elements %c1_i64, %14889, %c4096_i64 : tensor<3xi64> + %14890 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_5218, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5219 = tensor.dim %14890, %c1 : tensor<1x?x4096xi64> + %14891 = arith.index_cast %dim_5219 : index to i64 + %from_elements_5220 = tensor.from_elements %c1_i64, %14891, %c4096_i64, %c1_i64 : tensor<4xi64> + %14892 = stablehlo.dynamic_reshape %14890, %from_elements_5220 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %14893 = stablehlo.dynamic_broadcast_in_dim %14888, %from_elements_5218, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5221 = tensor.dim %14893, %c1 : tensor<1x?x4096xi64> + %14894 = arith.index_cast %dim_5221 : index to i64 + %from_elements_5222 = tensor.from_elements %c1_i64, %14894, %c4096_i64, %c1_i64 : tensor<4xi64> + %14895 = stablehlo.dynamic_reshape %14893, %from_elements_5222 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %14896 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_5218, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5223 = tensor.dim %14896, %c1 : tensor<1x?x4096xi64> + %14897 = arith.index_cast 
%dim_5223 : index to i64 + %from_elements_5224 = tensor.from_elements %c1_i64, %14897, %c4096_i64, %c1_i64 : tensor<4xi64> + %14898 = stablehlo.dynamic_reshape %14896, %from_elements_5224 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %14899 = stablehlo.concatenate %14892, %14895, %14898, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %14900 = "stablehlo.gather"(%14886, %14899) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %14901 = shape.shape_of %14900 : tensor<1x?x4096xf32> -> tensor<3xindex> + %14902 = shape.num_elements %14901 : tensor<3xindex> -> index + %14903 = stablehlo.compute_reshape_shape %14902, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %14904 = stablehlo.dynamic_reshape %14900, %14903 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %14905 = stablehlo.dot %14904, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %14906 = stablehlo.logistic %14905 : tensor + %14907 = shape.shape_of %14906 : tensor -> tensor<2xindex> + %14908 = shape.shape_of %14905 : tensor -> tensor<2xindex> + %14909 = shape.cstr_broadcastable %14907, %14908 : tensor<2xindex>, tensor<2xindex> + %14910 = shape.assuming %14909 -> (tensor) { + %19688 = shape.broadcast %14907, %14908 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %14906, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %14905, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %14911 = shape.shape_of %14910 : tensor -> tensor<2xindex> + %14912 = shape.cstr_broadcastable %14911, %14908 : tensor<2xindex>, tensor<2xindex> + %14913 = shape.assuming %14912 -> (tensor) { + %19688 = shape.broadcast %14911, %14908 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %14910, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %14905, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %14914 = stablehlo.dot %14913, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %14915 = stablehlo.reshape %14870 : (tensor<3x2xf32>) -> tensor<3x2x1xf32> + %dim_5225 = tensor.dim %14885, %c0 : tensor + %14916 = arith.index_cast %dim_5225 : index to i64 + %from_elements_5226 = tensor.from_elements %14916, %c1_i64 : tensor<2xi64> + %14917 = stablehlo.dynamic_reshape %14885, %from_elements_5226 : (tensor, tensor<2xi64>) -> tensor + %dim_5227 = tensor.dim %14882, %c0 : tensor + %14918 = arith.index_cast %dim_5227 : index to i64 + %from_elements_5228 = tensor.from_elements %14918, %c1_i64 : tensor<2xi64> + %14919 = stablehlo.dynamic_reshape %14882, %from_elements_5228 : (tensor, tensor<2xi64>) -> tensor + %14920 = stablehlo.concatenate %14917, %14919, dim = 1 : (tensor, tensor) -> tensor + %14921 = "stablehlo.gather"(%14915, %14920) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %14922 = shape.shape_of %14914 : tensor -> tensor<2xindex> + %14923 = shape.shape_of %14921 : tensor -> tensor<2xindex> + %14924 = shape.cstr_broadcastable %14922, %14923 : 
tensor<2xindex>, tensor<2xindex> + %14925 = shape.assuming %14924 -> (tensor) { + %19688 = shape.broadcast %14922, %14923 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %14914, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %14921, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %14926 = shape.shape_of %14925 : tensor -> tensor<2xindex> + %14927 = stablehlo.dynamic_broadcast_in_dim %14925, %14926, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %14928 = stablehlo.dynamic_broadcast_in_dim %213, %14926, dims = [] : (tensor, tensor<2xindex>) -> tensor + %14929 = stablehlo.multiply %14927, %14928 : tensor + %dim_5229 = tensor.dim %14888, %c0 : tensor + %14930 = arith.index_cast %dim_5229 : index to i64 + %dim_5230 = tensor.dim %14925, %c0 : tensor + %14931 = arith.index_cast %dim_5230 : index to i64 + %14932 = arith.maxsi %14930, %14931 : i64 + %14933 = arith.index_cast %14932 : i64 to index + %from_elements_5231 = tensor.from_elements %14933, %c4096 : tensor<2xindex> + %14934 = stablehlo.dynamic_broadcast_in_dim %14888, %from_elements_5231, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_5232 = tensor.dim %14934, %c0 : tensor + %14935 = arith.index_cast %dim_5232 : index to i64 + %from_elements_5233 = tensor.from_elements %14935, %c4096_i64 : tensor<2xi64> + %14936 = stablehlo.real_dynamic_slice %14929, %c_22, %from_elements_5233, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_5234 = tensor.from_elements %14935, %c4096_i64, %c1_i64 : tensor<3xi64> + %14937 = stablehlo.dynamic_reshape %14934, %from_elements_5234 : (tensor, tensor<3xi64>) -> tensor + %14938 = stablehlo.dynamic_iota %from_elements_5234, dim = 1 : (tensor<3xi64>) -> tensor + %14939 = stablehlo.concatenate %14937, %14938, dim = 2 : (tensor, tensor) -> tensor + %14940 = "stablehlo.scatter"(%cst_2, %14939, %14936) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %14941 = stablehlo.slice %14875 [1:2, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %14942 = stablehlo.reshape %14941 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %14943 = stablehlo.custom_call @byteir.non_zero(%14942) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_5235 = tensor.dim %14943, %c0 : tensor + %14944 = arith.index_cast %dim_5235 : index to i64 + %from_elements_5236 = tensor.from_elements %14944, %c1_i64 : tensor<2xi64> + %14945 = stablehlo.real_dynamic_slice %14943, %c_22, %from_elements_5236, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5237 = tensor.dim %14945, %c0 : tensor + %14946 = arith.index_cast %dim_5237 : index to i64 + %from_elements_5238 = tensor.from_elements %14946 : tensor<1xi64> + %14947 = stablehlo.dynamic_reshape %14945, %from_elements_5238 : (tensor, tensor<1xi64>) -> tensor + %from_elements_5239 = tensor.from_elements %14944, %c2_i64 : tensor<2xi64> + %14948 = stablehlo.real_dynamic_slice %14943, %c_24, %from_elements_5239, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5240 = tensor.dim %14948, %c0 : tensor + %14949 = arith.index_cast 
%dim_5240 : index to i64 + %from_elements_5241 = tensor.from_elements %14949 : tensor<1xi64> + %14950 = stablehlo.dynamic_reshape %14948, %from_elements_5241 : (tensor, tensor<1xi64>) -> tensor + %dim_5242 = tensor.dim %14950, %c0 : tensor + %14951 = arith.index_cast %dim_5242 : index to i64 + %from_elements_5243 = tensor.from_elements %14951, %c1_i64 : tensor<2xi64> + %14952 = stablehlo.dynamic_reshape %14950, %from_elements_5243 : (tensor, tensor<2xi64>) -> tensor + %dim_5244 = tensor.dim %14952, %c0 : tensor + %14953 = arith.index_cast %dim_5244 : index to i64 + %from_elements_5245 = tensor.from_elements %c1_i64, %14953, %c4096_i64 : tensor<3xi64> + %14954 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_5245, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5246 = tensor.dim %14954, %c1 : tensor<1x?x4096xi64> + %14955 = arith.index_cast %dim_5246 : index to i64 + %from_elements_5247 = tensor.from_elements %c1_i64, %14955, %c4096_i64, %c1_i64 : tensor<4xi64> + %14956 = stablehlo.dynamic_reshape %14954, %from_elements_5247 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %14957 = stablehlo.dynamic_broadcast_in_dim %14952, %from_elements_5245, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5248 = tensor.dim %14957, %c1 : tensor<1x?x4096xi64> + %14958 = arith.index_cast %dim_5248 : index to i64 + %from_elements_5249 = tensor.from_elements %c1_i64, %14958, %c4096_i64, %c1_i64 : tensor<4xi64> + %14959 = stablehlo.dynamic_reshape %14957, %from_elements_5249 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %14960 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_5245, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5250 = tensor.dim %14960, %c1 : tensor<1x?x4096xi64> + %14961 = arith.index_cast %dim_5250 : index to i64 + %from_elements_5251 = tensor.from_elements %c1_i64, %14961, %c4096_i64, %c1_i64 : tensor<4xi64> + %14962 = stablehlo.dynamic_reshape %14960, %from_elements_5251 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %14963 = stablehlo.concatenate %14956, %14959, %14962, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %14964 = "stablehlo.gather"(%14886, %14963) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %14965 = shape.shape_of %14964 : tensor<1x?x4096xf32> -> tensor<3xindex> + %14966 = shape.num_elements %14965 : tensor<3xindex> -> index + %14967 = stablehlo.compute_reshape_shape %14966, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %14968 = stablehlo.dynamic_reshape %14964, %14967 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %14969 = stablehlo.dot %14968, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %14970 = stablehlo.logistic %14969 : tensor + %14971 = shape.shape_of %14970 : tensor -> tensor<2xindex> + %14972 = shape.shape_of %14969 : tensor -> tensor<2xindex> + %14973 = shape.cstr_broadcastable %14971, %14972 : tensor<2xindex>, tensor<2xindex> + %14974 = shape.assuming %14973 -> (tensor) { + %19688 = shape.broadcast %14971, %14972 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %14970, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %14969, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = 
stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %14975 = shape.shape_of %14974 : tensor -> tensor<2xindex> + %14976 = shape.cstr_broadcastable %14975, %14972 : tensor<2xindex>, tensor<2xindex> + %14977 = shape.assuming %14976 -> (tensor) { + %19688 = shape.broadcast %14975, %14972 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %14974, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %14969, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %14978 = stablehlo.dot %14977, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_5252 = tensor.dim %14950, %c0 : tensor + %14979 = arith.index_cast %dim_5252 : index to i64 + %from_elements_5253 = tensor.from_elements %14979, %c1_i64 : tensor<2xi64> + %14980 = stablehlo.dynamic_reshape %14950, %from_elements_5253 : (tensor, tensor<2xi64>) -> tensor + %dim_5254 = tensor.dim %14947, %c0 : tensor + %14981 = arith.index_cast %dim_5254 : index to i64 + %from_elements_5255 = tensor.from_elements %14981, %c1_i64 : tensor<2xi64> + %14982 = stablehlo.dynamic_reshape %14947, %from_elements_5255 : (tensor, tensor<2xi64>) -> tensor + %14983 = stablehlo.concatenate %14980, %14982, dim = 1 : (tensor, tensor) -> tensor + %14984 = "stablehlo.gather"(%14915, %14983) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %14985 = shape.shape_of %14978 : tensor -> tensor<2xindex> + %14986 = shape.shape_of %14984 : tensor -> tensor<2xindex> + %14987 = shape.cstr_broadcastable %14985, %14986 : tensor<2xindex>, tensor<2xindex> + %14988 = shape.assuming %14987 -> (tensor) { + %19688 = shape.broadcast %14985, %14986 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %14978, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %14984, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %14989 = shape.shape_of %14988 : tensor -> tensor<2xindex> + %14990 = stablehlo.dynamic_broadcast_in_dim %14988, %14989, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %14991 = stablehlo.dynamic_broadcast_in_dim %213, %14989, dims = [] : (tensor, tensor<2xindex>) -> tensor + %14992 = stablehlo.multiply %14990, %14991 : tensor + %dim_5256 = tensor.dim %14952, %c0 : tensor + %14993 = arith.index_cast %dim_5256 : index to i64 + %dim_5257 = tensor.dim %14988, %c0 : tensor + %14994 = arith.index_cast %dim_5257 : index to i64 + %14995 = arith.maxsi %14993, %14994 : i64 + %14996 = arith.index_cast %14995 : i64 to index + %from_elements_5258 = tensor.from_elements %14996, %c4096 : tensor<2xindex> + %14997 = stablehlo.dynamic_broadcast_in_dim %14952, %from_elements_5258, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_5259 = tensor.dim %14997, %c0 : tensor + %14998 = arith.index_cast %dim_5259 : index to i64 + %from_elements_5260 = tensor.from_elements %14998, %c4096_i64 : tensor<2xi64> + %14999 = stablehlo.real_dynamic_slice %14992, %c_22, %from_elements_5260, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_5261 = tensor.from_elements %14998, %c4096_i64, %c1_i64 : tensor<3xi64> + %15000 = 
stablehlo.dynamic_reshape %14997, %from_elements_5261 : (tensor, tensor<3xi64>) -> tensor + %15001 = stablehlo.dynamic_iota %from_elements_5261, dim = 1 : (tensor<3xi64>) -> tensor + %15002 = stablehlo.concatenate %15000, %15001, dim = 2 : (tensor, tensor) -> tensor + %15003 = "stablehlo.scatter"(%14940, %15002, %14999) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %15004 = stablehlo.slice %14875 [2:3, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %15005 = stablehlo.reshape %15004 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %15006 = stablehlo.custom_call @byteir.non_zero(%15005) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_5262 = tensor.dim %15006, %c0 : tensor + %15007 = arith.index_cast %dim_5262 : index to i64 + %from_elements_5263 = tensor.from_elements %15007, %c1_i64 : tensor<2xi64> + %15008 = stablehlo.real_dynamic_slice %15006, %c_22, %from_elements_5263, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5264 = tensor.dim %15008, %c0 : tensor + %15009 = arith.index_cast %dim_5264 : index to i64 + %from_elements_5265 = tensor.from_elements %15009 : tensor<1xi64> + %15010 = stablehlo.dynamic_reshape %15008, %from_elements_5265 : (tensor, tensor<1xi64>) -> tensor + %from_elements_5266 = tensor.from_elements %15007, %c2_i64 : tensor<2xi64> + %15011 = stablehlo.real_dynamic_slice %15006, %c_24, %from_elements_5266, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5267 = tensor.dim %15011, %c0 : tensor + %15012 = arith.index_cast %dim_5267 : index to i64 + %from_elements_5268 = tensor.from_elements %15012 : tensor<1xi64> + %15013 = stablehlo.dynamic_reshape %15011, %from_elements_5268 : (tensor, tensor<1xi64>) -> tensor + %dim_5269 = tensor.dim %15013, %c0 : tensor + %15014 = arith.index_cast %dim_5269 : index to i64 + %from_elements_5270 = tensor.from_elements %15014, %c1_i64 : tensor<2xi64> + %15015 = stablehlo.dynamic_reshape %15013, %from_elements_5270 : (tensor, tensor<2xi64>) -> tensor + %dim_5271 = tensor.dim %15015, %c0 : tensor + %15016 = arith.index_cast %dim_5271 : index to i64 + %from_elements_5272 = tensor.from_elements %c1_i64, %15016, %c4096_i64 : tensor<3xi64> + %15017 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_5272, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5273 = tensor.dim %15017, %c1 : tensor<1x?x4096xi64> + %15018 = arith.index_cast %dim_5273 : index to i64 + %from_elements_5274 = tensor.from_elements %c1_i64, %15018, %c4096_i64, %c1_i64 : tensor<4xi64> + %15019 = stablehlo.dynamic_reshape %15017, %from_elements_5274 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %15020 = stablehlo.dynamic_broadcast_in_dim %15015, %from_elements_5272, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5275 = tensor.dim %15020, %c1 : tensor<1x?x4096xi64> + %15021 = arith.index_cast %dim_5275 : index to i64 + %from_elements_5276 = tensor.from_elements %c1_i64, %15021, %c4096_i64, %c1_i64 : tensor<4xi64> + %15022 = stablehlo.dynamic_reshape %15020, %from_elements_5276 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %15023 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_5272, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> 
tensor<1x?x4096xi64> + %dim_5277 = tensor.dim %15023, %c1 : tensor<1x?x4096xi64> + %15024 = arith.index_cast %dim_5277 : index to i64 + %from_elements_5278 = tensor.from_elements %c1_i64, %15024, %c4096_i64, %c1_i64 : tensor<4xi64> + %15025 = stablehlo.dynamic_reshape %15023, %from_elements_5278 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %15026 = stablehlo.concatenate %15019, %15022, %15025, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %15027 = "stablehlo.gather"(%14886, %15026) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %15028 = shape.shape_of %15027 : tensor<1x?x4096xf32> -> tensor<3xindex> + %15029 = shape.num_elements %15028 : tensor<3xindex> -> index + %15030 = stablehlo.compute_reshape_shape %15029, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %15031 = stablehlo.dynamic_reshape %15027, %15030 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %15032 = stablehlo.dot %15031, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %15033 = stablehlo.logistic %15032 : tensor + %15034 = shape.shape_of %15033 : tensor -> tensor<2xindex> + %15035 = shape.shape_of %15032 : tensor -> tensor<2xindex> + %15036 = shape.cstr_broadcastable %15034, %15035 : tensor<2xindex>, tensor<2xindex> + %15037 = shape.assuming %15036 -> (tensor) { + %19688 = shape.broadcast %15034, %15035 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %15033, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %15032, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %15038 = shape.shape_of %15037 : tensor -> tensor<2xindex> + %15039 = shape.cstr_broadcastable %15038, %15035 : tensor<2xindex>, tensor<2xindex> + %15040 = shape.assuming %15039 -> (tensor) { + %19688 = shape.broadcast %15038, %15035 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %15037, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %15032, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %15041 = stablehlo.dot %15040, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_5279 = tensor.dim %15013, %c0 : tensor + %15042 = arith.index_cast %dim_5279 : index to i64 + %from_elements_5280 = tensor.from_elements %15042, %c1_i64 : tensor<2xi64> + %15043 = stablehlo.dynamic_reshape %15013, %from_elements_5280 : (tensor, tensor<2xi64>) -> tensor + %dim_5281 = tensor.dim %15010, %c0 : tensor + %15044 = arith.index_cast %dim_5281 : index to i64 + %from_elements_5282 = tensor.from_elements %15044, %c1_i64 : tensor<2xi64> + %15045 = stablehlo.dynamic_reshape %15010, %from_elements_5282 : (tensor, tensor<2xi64>) -> tensor + %15046 = stablehlo.concatenate %15043, %15045, dim = 1 : (tensor, tensor) -> tensor + %15047 = "stablehlo.gather"(%14915, %15046) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %15048 = shape.shape_of %15041 : tensor -> tensor<2xindex> + %15049 = shape.shape_of %15047 : tensor -> tensor<2xindex> + %15050 = 
shape.cstr_broadcastable %15048, %15049 : tensor<2xindex>, tensor<2xindex> + %15051 = shape.assuming %15050 -> (tensor) { + %19688 = shape.broadcast %15048, %15049 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %15041, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %15047, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %15052 = shape.shape_of %15051 : tensor -> tensor<2xindex> + %15053 = stablehlo.dynamic_broadcast_in_dim %15051, %15052, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %15054 = stablehlo.dynamic_broadcast_in_dim %213, %15052, dims = [] : (tensor, tensor<2xindex>) -> tensor + %15055 = stablehlo.multiply %15053, %15054 : tensor + %dim_5283 = tensor.dim %15015, %c0 : tensor + %15056 = arith.index_cast %dim_5283 : index to i64 + %dim_5284 = tensor.dim %15051, %c0 : tensor + %15057 = arith.index_cast %dim_5284 : index to i64 + %15058 = arith.maxsi %15056, %15057 : i64 + %15059 = arith.index_cast %15058 : i64 to index + %from_elements_5285 = tensor.from_elements %15059, %c4096 : tensor<2xindex> + %15060 = stablehlo.dynamic_broadcast_in_dim %15015, %from_elements_5285, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_5286 = tensor.dim %15060, %c0 : tensor + %15061 = arith.index_cast %dim_5286 : index to i64 + %from_elements_5287 = tensor.from_elements %15061, %c4096_i64 : tensor<2xi64> + %15062 = stablehlo.real_dynamic_slice %15055, %c_22, %from_elements_5287, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_5288 = tensor.from_elements %15061, %c4096_i64, %c1_i64 : tensor<3xi64> + %15063 = stablehlo.dynamic_reshape %15060, %from_elements_5288 : (tensor, tensor<3xi64>) -> tensor + %15064 = stablehlo.dynamic_iota %from_elements_5288, dim = 1 : (tensor<3xi64>) -> tensor + %15065 = stablehlo.concatenate %15063, %15064, dim = 2 : (tensor, tensor) -> tensor + %15066 = "stablehlo.scatter"(%15003, %15065, %15062) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %15067 = stablehlo.slice %14875 [3:4, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %15068 = stablehlo.reshape %15067 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %15069 = stablehlo.custom_call @byteir.non_zero(%15068) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_5289 = tensor.dim %15069, %c0 : tensor + %15070 = arith.index_cast %dim_5289 : index to i64 + %from_elements_5290 = tensor.from_elements %15070, %c1_i64 : tensor<2xi64> + %15071 = stablehlo.real_dynamic_slice %15069, %c_22, %from_elements_5290, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5291 = tensor.dim %15071, %c0 : tensor + %15072 = arith.index_cast %dim_5291 : index to i64 + %from_elements_5292 = tensor.from_elements %15072 : tensor<1xi64> + %15073 = stablehlo.dynamic_reshape %15071, %from_elements_5292 : (tensor, tensor<1xi64>) -> tensor + %from_elements_5293 = tensor.from_elements %15070, %c2_i64 : tensor<2xi64> + %15074 = stablehlo.real_dynamic_slice %15069, %c_24, %from_elements_5293, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5294 = tensor.dim %15074, 
%c0 : tensor + %15075 = arith.index_cast %dim_5294 : index to i64 + %from_elements_5295 = tensor.from_elements %15075 : tensor<1xi64> + %15076 = stablehlo.dynamic_reshape %15074, %from_elements_5295 : (tensor, tensor<1xi64>) -> tensor + %dim_5296 = tensor.dim %15076, %c0 : tensor + %15077 = arith.index_cast %dim_5296 : index to i64 + %from_elements_5297 = tensor.from_elements %15077, %c1_i64 : tensor<2xi64> + %15078 = stablehlo.dynamic_reshape %15076, %from_elements_5297 : (tensor, tensor<2xi64>) -> tensor + %dim_5298 = tensor.dim %15078, %c0 : tensor + %15079 = arith.index_cast %dim_5298 : index to i64 + %from_elements_5299 = tensor.from_elements %c1_i64, %15079, %c4096_i64 : tensor<3xi64> + %15080 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_5299, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5300 = tensor.dim %15080, %c1 : tensor<1x?x4096xi64> + %15081 = arith.index_cast %dim_5300 : index to i64 + %from_elements_5301 = tensor.from_elements %c1_i64, %15081, %c4096_i64, %c1_i64 : tensor<4xi64> + %15082 = stablehlo.dynamic_reshape %15080, %from_elements_5301 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %15083 = stablehlo.dynamic_broadcast_in_dim %15078, %from_elements_5299, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5302 = tensor.dim %15083, %c1 : tensor<1x?x4096xi64> + %15084 = arith.index_cast %dim_5302 : index to i64 + %from_elements_5303 = tensor.from_elements %c1_i64, %15084, %c4096_i64, %c1_i64 : tensor<4xi64> + %15085 = stablehlo.dynamic_reshape %15083, %from_elements_5303 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %15086 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_5299, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5304 = tensor.dim %15086, %c1 : tensor<1x?x4096xi64> + %15087 = arith.index_cast %dim_5304 : index to i64 + %from_elements_5305 = tensor.from_elements %c1_i64, %15087, %c4096_i64, %c1_i64 : tensor<4xi64> + %15088 = stablehlo.dynamic_reshape %15086, %from_elements_5305 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %15089 = stablehlo.concatenate %15082, %15085, %15088, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %15090 = "stablehlo.gather"(%14886, %15089) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %15091 = shape.shape_of %15090 : tensor<1x?x4096xf32> -> tensor<3xindex> + %15092 = shape.num_elements %15091 : tensor<3xindex> -> index + %15093 = stablehlo.compute_reshape_shape %15092, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %15094 = stablehlo.dynamic_reshape %15090, %15093 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %15095 = stablehlo.dot %15094, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %15096 = stablehlo.logistic %15095 : tensor + %15097 = shape.shape_of %15096 : tensor -> tensor<2xindex> + %15098 = shape.shape_of %15095 : tensor -> tensor<2xindex> + %15099 = shape.cstr_broadcastable %15097, %15098 : tensor<2xindex>, tensor<2xindex> + %15100 = shape.assuming %15099 -> (tensor) { + %19688 = shape.broadcast %15097, %15098 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %15096, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %15095, %19688, dims = [0, 1] : 
(tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %15101 = shape.shape_of %15100 : tensor -> tensor<2xindex> + %15102 = shape.cstr_broadcastable %15101, %15098 : tensor<2xindex>, tensor<2xindex> + %15103 = shape.assuming %15102 -> (tensor) { + %19688 = shape.broadcast %15101, %15098 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %15100, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %15095, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %15104 = stablehlo.dot %15103, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_5306 = tensor.dim %15076, %c0 : tensor + %15105 = arith.index_cast %dim_5306 : index to i64 + %from_elements_5307 = tensor.from_elements %15105, %c1_i64 : tensor<2xi64> + %15106 = stablehlo.dynamic_reshape %15076, %from_elements_5307 : (tensor, tensor<2xi64>) -> tensor + %dim_5308 = tensor.dim %15073, %c0 : tensor + %15107 = arith.index_cast %dim_5308 : index to i64 + %from_elements_5309 = tensor.from_elements %15107, %c1_i64 : tensor<2xi64> + %15108 = stablehlo.dynamic_reshape %15073, %from_elements_5309 : (tensor, tensor<2xi64>) -> tensor + %15109 = stablehlo.concatenate %15106, %15108, dim = 1 : (tensor, tensor) -> tensor + %15110 = "stablehlo.gather"(%14915, %15109) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %15111 = shape.shape_of %15104 : tensor -> tensor<2xindex> + %15112 = shape.shape_of %15110 : tensor -> tensor<2xindex> + %15113 = shape.cstr_broadcastable %15111, %15112 : tensor<2xindex>, tensor<2xindex> + %15114 = shape.assuming %15113 -> (tensor) { + %19688 = shape.broadcast %15111, %15112 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %15104, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %15110, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %15115 = shape.shape_of %15114 : tensor -> tensor<2xindex> + %15116 = stablehlo.dynamic_broadcast_in_dim %15114, %15115, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %15117 = stablehlo.dynamic_broadcast_in_dim %213, %15115, dims = [] : (tensor, tensor<2xindex>) -> tensor + %15118 = stablehlo.multiply %15116, %15117 : tensor + %dim_5310 = tensor.dim %15078, %c0 : tensor + %15119 = arith.index_cast %dim_5310 : index to i64 + %dim_5311 = tensor.dim %15114, %c0 : tensor + %15120 = arith.index_cast %dim_5311 : index to i64 + %15121 = arith.maxsi %15119, %15120 : i64 + %15122 = arith.index_cast %15121 : i64 to index + %from_elements_5312 = tensor.from_elements %15122, %c4096 : tensor<2xindex> + %15123 = stablehlo.dynamic_broadcast_in_dim %15078, %from_elements_5312, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_5313 = tensor.dim %15123, %c0 : tensor + %15124 = arith.index_cast %dim_5313 : index to i64 + %from_elements_5314 = tensor.from_elements %15124, %c4096_i64 : tensor<2xi64> + %15125 = stablehlo.real_dynamic_slice %15118, %c_22, %from_elements_5314, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_5315 = tensor.from_elements %15124, %c4096_i64, 
%c1_i64 : tensor<3xi64> + %15126 = stablehlo.dynamic_reshape %15123, %from_elements_5315 : (tensor, tensor<3xi64>) -> tensor + %15127 = stablehlo.dynamic_iota %from_elements_5315, dim = 1 : (tensor<3xi64>) -> tensor + %15128 = stablehlo.concatenate %15126, %15127, dim = 2 : (tensor, tensor) -> tensor + %15129 = "stablehlo.scatter"(%15066, %15128, %15125) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %15130 = stablehlo.slice %14875 [4:5, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %15131 = stablehlo.reshape %15130 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %15132 = stablehlo.custom_call @byteir.non_zero(%15131) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_5316 = tensor.dim %15132, %c0 : tensor + %15133 = arith.index_cast %dim_5316 : index to i64 + %from_elements_5317 = tensor.from_elements %15133, %c1_i64 : tensor<2xi64> + %15134 = stablehlo.real_dynamic_slice %15132, %c_22, %from_elements_5317, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5318 = tensor.dim %15134, %c0 : tensor + %15135 = arith.index_cast %dim_5318 : index to i64 + %from_elements_5319 = tensor.from_elements %15135 : tensor<1xi64> + %15136 = stablehlo.dynamic_reshape %15134, %from_elements_5319 : (tensor, tensor<1xi64>) -> tensor + %from_elements_5320 = tensor.from_elements %15133, %c2_i64 : tensor<2xi64> + %15137 = stablehlo.real_dynamic_slice %15132, %c_24, %from_elements_5320, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5321 = tensor.dim %15137, %c0 : tensor + %15138 = arith.index_cast %dim_5321 : index to i64 + %from_elements_5322 = tensor.from_elements %15138 : tensor<1xi64> + %15139 = stablehlo.dynamic_reshape %15137, %from_elements_5322 : (tensor, tensor<1xi64>) -> tensor + %dim_5323 = tensor.dim %15139, %c0 : tensor + %15140 = arith.index_cast %dim_5323 : index to i64 + %from_elements_5324 = tensor.from_elements %15140, %c1_i64 : tensor<2xi64> + %15141 = stablehlo.dynamic_reshape %15139, %from_elements_5324 : (tensor, tensor<2xi64>) -> tensor + %dim_5325 = tensor.dim %15141, %c0 : tensor + %15142 = arith.index_cast %dim_5325 : index to i64 + %from_elements_5326 = tensor.from_elements %c1_i64, %15142, %c4096_i64 : tensor<3xi64> + %15143 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_5326, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5327 = tensor.dim %15143, %c1 : tensor<1x?x4096xi64> + %15144 = arith.index_cast %dim_5327 : index to i64 + %from_elements_5328 = tensor.from_elements %c1_i64, %15144, %c4096_i64, %c1_i64 : tensor<4xi64> + %15145 = stablehlo.dynamic_reshape %15143, %from_elements_5328 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %15146 = stablehlo.dynamic_broadcast_in_dim %15141, %from_elements_5326, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5329 = tensor.dim %15146, %c1 : tensor<1x?x4096xi64> + %15147 = arith.index_cast %dim_5329 : index to i64 + %from_elements_5330 = tensor.from_elements %c1_i64, %15147, %c4096_i64, %c1_i64 : tensor<4xi64> + %15148 = stablehlo.dynamic_reshape %15146, %from_elements_5330 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %15149 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_5326, dims = [2] : 
(tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5331 = tensor.dim %15149, %c1 : tensor<1x?x4096xi64> + %15150 = arith.index_cast %dim_5331 : index to i64 + %from_elements_5332 = tensor.from_elements %c1_i64, %15150, %c4096_i64, %c1_i64 : tensor<4xi64> + %15151 = stablehlo.dynamic_reshape %15149, %from_elements_5332 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %15152 = stablehlo.concatenate %15145, %15148, %15151, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %15153 = "stablehlo.gather"(%14886, %15152) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %15154 = shape.shape_of %15153 : tensor<1x?x4096xf32> -> tensor<3xindex> + %15155 = shape.num_elements %15154 : tensor<3xindex> -> index + %15156 = stablehlo.compute_reshape_shape %15155, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %15157 = stablehlo.dynamic_reshape %15153, %15156 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %15158 = stablehlo.dot %15157, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %15159 = stablehlo.logistic %15158 : tensor + %15160 = shape.shape_of %15159 : tensor -> tensor<2xindex> + %15161 = shape.shape_of %15158 : tensor -> tensor<2xindex> + %15162 = shape.cstr_broadcastable %15160, %15161 : tensor<2xindex>, tensor<2xindex> + %15163 = shape.assuming %15162 -> (tensor) { + %19688 = shape.broadcast %15160, %15161 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %15159, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %15158, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %15164 = shape.shape_of %15163 : tensor -> tensor<2xindex> + %15165 = shape.cstr_broadcastable %15164, %15161 : tensor<2xindex>, tensor<2xindex> + %15166 = shape.assuming %15165 -> (tensor) { + %19688 = shape.broadcast %15164, %15161 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %15163, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %15158, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %15167 = stablehlo.dot %15166, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_5333 = tensor.dim %15139, %c0 : tensor + %15168 = arith.index_cast %dim_5333 : index to i64 + %from_elements_5334 = tensor.from_elements %15168, %c1_i64 : tensor<2xi64> + %15169 = stablehlo.dynamic_reshape %15139, %from_elements_5334 : (tensor, tensor<2xi64>) -> tensor + %dim_5335 = tensor.dim %15136, %c0 : tensor + %15170 = arith.index_cast %dim_5335 : index to i64 + %from_elements_5336 = tensor.from_elements %15170, %c1_i64 : tensor<2xi64> + %15171 = stablehlo.dynamic_reshape %15136, %from_elements_5336 : (tensor, tensor<2xi64>) -> tensor + %15172 = stablehlo.concatenate %15169, %15171, dim = 1 : (tensor, tensor) -> tensor + %15173 = "stablehlo.gather"(%14915, %15172) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %15174 = shape.shape_of %15167 : tensor -> tensor<2xindex> + %15175 = shape.shape_of %15173 : tensor -> 
tensor<2xindex> + %15176 = shape.cstr_broadcastable %15174, %15175 : tensor<2xindex>, tensor<2xindex> + %15177 = shape.assuming %15176 -> (tensor) { + %19688 = shape.broadcast %15174, %15175 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %15167, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %15173, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %15178 = shape.shape_of %15177 : tensor -> tensor<2xindex> + %15179 = stablehlo.dynamic_broadcast_in_dim %15177, %15178, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %15180 = stablehlo.dynamic_broadcast_in_dim %213, %15178, dims = [] : (tensor, tensor<2xindex>) -> tensor + %15181 = stablehlo.multiply %15179, %15180 : tensor + %dim_5337 = tensor.dim %15141, %c0 : tensor + %15182 = arith.index_cast %dim_5337 : index to i64 + %dim_5338 = tensor.dim %15177, %c0 : tensor + %15183 = arith.index_cast %dim_5338 : index to i64 + %15184 = arith.maxsi %15182, %15183 : i64 + %15185 = arith.index_cast %15184 : i64 to index + %from_elements_5339 = tensor.from_elements %15185, %c4096 : tensor<2xindex> + %15186 = stablehlo.dynamic_broadcast_in_dim %15141, %from_elements_5339, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_5340 = tensor.dim %15186, %c0 : tensor + %15187 = arith.index_cast %dim_5340 : index to i64 + %from_elements_5341 = tensor.from_elements %15187, %c4096_i64 : tensor<2xi64> + %15188 = stablehlo.real_dynamic_slice %15181, %c_22, %from_elements_5341, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_5342 = tensor.from_elements %15187, %c4096_i64, %c1_i64 : tensor<3xi64> + %15189 = stablehlo.dynamic_reshape %15186, %from_elements_5342 : (tensor, tensor<3xi64>) -> tensor + %15190 = stablehlo.dynamic_iota %from_elements_5342, dim = 1 : (tensor<3xi64>) -> tensor + %15191 = stablehlo.concatenate %15189, %15190, dim = 2 : (tensor, tensor) -> tensor + %15192 = "stablehlo.scatter"(%15129, %15191, %15188) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %15193 = stablehlo.slice %14875 [5:6, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %15194 = stablehlo.reshape %15193 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %15195 = stablehlo.custom_call @byteir.non_zero(%15194) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_5343 = tensor.dim %15195, %c0 : tensor + %15196 = arith.index_cast %dim_5343 : index to i64 + %from_elements_5344 = tensor.from_elements %15196, %c1_i64 : tensor<2xi64> + %15197 = stablehlo.real_dynamic_slice %15195, %c_22, %from_elements_5344, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5345 = tensor.dim %15197, %c0 : tensor + %15198 = arith.index_cast %dim_5345 : index to i64 + %from_elements_5346 = tensor.from_elements %15198 : tensor<1xi64> + %15199 = stablehlo.dynamic_reshape %15197, %from_elements_5346 : (tensor, tensor<1xi64>) -> tensor + %from_elements_5347 = tensor.from_elements %15196, %c2_i64 : tensor<2xi64> + %15200 = stablehlo.real_dynamic_slice %15195, %c_24, %from_elements_5347, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + 
%dim_5348 = tensor.dim %15200, %c0 : tensor + %15201 = arith.index_cast %dim_5348 : index to i64 + %from_elements_5349 = tensor.from_elements %15201 : tensor<1xi64> + %15202 = stablehlo.dynamic_reshape %15200, %from_elements_5349 : (tensor, tensor<1xi64>) -> tensor + %dim_5350 = tensor.dim %15202, %c0 : tensor + %15203 = arith.index_cast %dim_5350 : index to i64 + %from_elements_5351 = tensor.from_elements %15203, %c1_i64 : tensor<2xi64> + %15204 = stablehlo.dynamic_reshape %15202, %from_elements_5351 : (tensor, tensor<2xi64>) -> tensor + %dim_5352 = tensor.dim %15204, %c0 : tensor + %15205 = arith.index_cast %dim_5352 : index to i64 + %from_elements_5353 = tensor.from_elements %c1_i64, %15205, %c4096_i64 : tensor<3xi64> + %15206 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_5353, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5354 = tensor.dim %15206, %c1 : tensor<1x?x4096xi64> + %15207 = arith.index_cast %dim_5354 : index to i64 + %from_elements_5355 = tensor.from_elements %c1_i64, %15207, %c4096_i64, %c1_i64 : tensor<4xi64> + %15208 = stablehlo.dynamic_reshape %15206, %from_elements_5355 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %15209 = stablehlo.dynamic_broadcast_in_dim %15204, %from_elements_5353, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5356 = tensor.dim %15209, %c1 : tensor<1x?x4096xi64> + %15210 = arith.index_cast %dim_5356 : index to i64 + %from_elements_5357 = tensor.from_elements %c1_i64, %15210, %c4096_i64, %c1_i64 : tensor<4xi64> + %15211 = stablehlo.dynamic_reshape %15209, %from_elements_5357 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %15212 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_5353, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5358 = tensor.dim %15212, %c1 : tensor<1x?x4096xi64> + %15213 = arith.index_cast %dim_5358 : index to i64 + %from_elements_5359 = tensor.from_elements %c1_i64, %15213, %c4096_i64, %c1_i64 : tensor<4xi64> + %15214 = stablehlo.dynamic_reshape %15212, %from_elements_5359 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %15215 = stablehlo.concatenate %15208, %15211, %15214, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %15216 = "stablehlo.gather"(%14886, %15215) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %15217 = shape.shape_of %15216 : tensor<1x?x4096xf32> -> tensor<3xindex> + %15218 = shape.num_elements %15217 : tensor<3xindex> -> index + %15219 = stablehlo.compute_reshape_shape %15218, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %15220 = stablehlo.dynamic_reshape %15216, %15219 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %15221 = stablehlo.dot %15220, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %15222 = stablehlo.logistic %15221 : tensor + %15223 = shape.shape_of %15222 : tensor -> tensor<2xindex> + %15224 = shape.shape_of %15221 : tensor -> tensor<2xindex> + %15225 = shape.cstr_broadcastable %15223, %15224 : tensor<2xindex>, tensor<2xindex> + %15226 = shape.assuming %15225 -> (tensor) { + %19688 = shape.broadcast %15223, %15224 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %15222, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %15221, 
%19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %15227 = shape.shape_of %15226 : tensor -> tensor<2xindex> + %15228 = shape.cstr_broadcastable %15227, %15224 : tensor<2xindex>, tensor<2xindex> + %15229 = shape.assuming %15228 -> (tensor) { + %19688 = shape.broadcast %15227, %15224 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %15226, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %15221, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %15230 = stablehlo.dot %15229, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_5360 = tensor.dim %15202, %c0 : tensor + %15231 = arith.index_cast %dim_5360 : index to i64 + %from_elements_5361 = tensor.from_elements %15231, %c1_i64 : tensor<2xi64> + %15232 = stablehlo.dynamic_reshape %15202, %from_elements_5361 : (tensor, tensor<2xi64>) -> tensor + %dim_5362 = tensor.dim %15199, %c0 : tensor + %15233 = arith.index_cast %dim_5362 : index to i64 + %from_elements_5363 = tensor.from_elements %15233, %c1_i64 : tensor<2xi64> + %15234 = stablehlo.dynamic_reshape %15199, %from_elements_5363 : (tensor, tensor<2xi64>) -> tensor + %15235 = stablehlo.concatenate %15232, %15234, dim = 1 : (tensor, tensor) -> tensor + %15236 = "stablehlo.gather"(%14915, %15235) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %15237 = shape.shape_of %15230 : tensor -> tensor<2xindex> + %15238 = shape.shape_of %15236 : tensor -> tensor<2xindex> + %15239 = shape.cstr_broadcastable %15237, %15238 : tensor<2xindex>, tensor<2xindex> + %15240 = shape.assuming %15239 -> (tensor) { + %19688 = shape.broadcast %15237, %15238 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %15230, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %15236, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %15241 = shape.shape_of %15240 : tensor -> tensor<2xindex> + %15242 = stablehlo.dynamic_broadcast_in_dim %15240, %15241, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %15243 = stablehlo.dynamic_broadcast_in_dim %213, %15241, dims = [] : (tensor, tensor<2xindex>) -> tensor + %15244 = stablehlo.multiply %15242, %15243 : tensor + %dim_5364 = tensor.dim %15204, %c0 : tensor + %15245 = arith.index_cast %dim_5364 : index to i64 + %dim_5365 = tensor.dim %15240, %c0 : tensor + %15246 = arith.index_cast %dim_5365 : index to i64 + %15247 = arith.maxsi %15245, %15246 : i64 + %15248 = arith.index_cast %15247 : i64 to index + %from_elements_5366 = tensor.from_elements %15248, %c4096 : tensor<2xindex> + %15249 = stablehlo.dynamic_broadcast_in_dim %15204, %from_elements_5366, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_5367 = tensor.dim %15249, %c0 : tensor + %15250 = arith.index_cast %dim_5367 : index to i64 + %from_elements_5368 = tensor.from_elements %15250, %c4096_i64 : tensor<2xi64> + %15251 = stablehlo.real_dynamic_slice %15244, %c_22, %from_elements_5368, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_5369 = tensor.from_elements 
%15250, %c4096_i64, %c1_i64 : tensor<3xi64> + %15252 = stablehlo.dynamic_reshape %15249, %from_elements_5369 : (tensor, tensor<3xi64>) -> tensor + %15253 = stablehlo.dynamic_iota %from_elements_5369, dim = 1 : (tensor<3xi64>) -> tensor + %15254 = stablehlo.concatenate %15252, %15253, dim = 2 : (tensor, tensor) -> tensor + %15255 = "stablehlo.scatter"(%15192, %15254, %15251) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %15256 = stablehlo.slice %14875 [6:7, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %15257 = stablehlo.reshape %15256 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %15258 = stablehlo.custom_call @byteir.non_zero(%15257) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_5370 = tensor.dim %15258, %c0 : tensor + %15259 = arith.index_cast %dim_5370 : index to i64 + %from_elements_5371 = tensor.from_elements %15259, %c1_i64 : tensor<2xi64> + %15260 = stablehlo.real_dynamic_slice %15258, %c_22, %from_elements_5371, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5372 = tensor.dim %15260, %c0 : tensor + %15261 = arith.index_cast %dim_5372 : index to i64 + %from_elements_5373 = tensor.from_elements %15261 : tensor<1xi64> + %15262 = stablehlo.dynamic_reshape %15260, %from_elements_5373 : (tensor, tensor<1xi64>) -> tensor + %from_elements_5374 = tensor.from_elements %15259, %c2_i64 : tensor<2xi64> + %15263 = stablehlo.real_dynamic_slice %15258, %c_24, %from_elements_5374, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5375 = tensor.dim %15263, %c0 : tensor + %15264 = arith.index_cast %dim_5375 : index to i64 + %from_elements_5376 = tensor.from_elements %15264 : tensor<1xi64> + %15265 = stablehlo.dynamic_reshape %15263, %from_elements_5376 : (tensor, tensor<1xi64>) -> tensor + %dim_5377 = tensor.dim %15265, %c0 : tensor + %15266 = arith.index_cast %dim_5377 : index to i64 + %from_elements_5378 = tensor.from_elements %15266, %c1_i64 : tensor<2xi64> + %15267 = stablehlo.dynamic_reshape %15265, %from_elements_5378 : (tensor, tensor<2xi64>) -> tensor + %dim_5379 = tensor.dim %15267, %c0 : tensor + %15268 = arith.index_cast %dim_5379 : index to i64 + %from_elements_5380 = tensor.from_elements %c1_i64, %15268, %c4096_i64 : tensor<3xi64> + %15269 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_5380, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5381 = tensor.dim %15269, %c1 : tensor<1x?x4096xi64> + %15270 = arith.index_cast %dim_5381 : index to i64 + %from_elements_5382 = tensor.from_elements %c1_i64, %15270, %c4096_i64, %c1_i64 : tensor<4xi64> + %15271 = stablehlo.dynamic_reshape %15269, %from_elements_5382 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %15272 = stablehlo.dynamic_broadcast_in_dim %15267, %from_elements_5380, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5383 = tensor.dim %15272, %c1 : tensor<1x?x4096xi64> + %15273 = arith.index_cast %dim_5383 : index to i64 + %from_elements_5384 = tensor.from_elements %c1_i64, %15273, %c4096_i64, %c1_i64 : tensor<4xi64> + %15274 = stablehlo.dynamic_reshape %15272, %from_elements_5384 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %15275 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_5380, 
dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5385 = tensor.dim %15275, %c1 : tensor<1x?x4096xi64> + %15276 = arith.index_cast %dim_5385 : index to i64 + %from_elements_5386 = tensor.from_elements %c1_i64, %15276, %c4096_i64, %c1_i64 : tensor<4xi64> + %15277 = stablehlo.dynamic_reshape %15275, %from_elements_5386 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %15278 = stablehlo.concatenate %15271, %15274, %15277, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %15279 = "stablehlo.gather"(%14886, %15278) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %15280 = shape.shape_of %15279 : tensor<1x?x4096xf32> -> tensor<3xindex> + %15281 = shape.num_elements %15280 : tensor<3xindex> -> index + %15282 = stablehlo.compute_reshape_shape %15281, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %15283 = stablehlo.dynamic_reshape %15279, %15282 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %15284 = stablehlo.dot %15283, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %15285 = stablehlo.logistic %15284 : tensor + %15286 = shape.shape_of %15285 : tensor -> tensor<2xindex> + %15287 = shape.shape_of %15284 : tensor -> tensor<2xindex> + %15288 = shape.cstr_broadcastable %15286, %15287 : tensor<2xindex>, tensor<2xindex> + %15289 = shape.assuming %15288 -> (tensor) { + %19688 = shape.broadcast %15286, %15287 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %15285, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %15284, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %15290 = shape.shape_of %15289 : tensor -> tensor<2xindex> + %15291 = shape.cstr_broadcastable %15290, %15287 : tensor<2xindex>, tensor<2xindex> + %15292 = shape.assuming %15291 -> (tensor) { + %19688 = shape.broadcast %15290, %15287 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %15289, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %15284, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %15293 = stablehlo.dot %15292, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_5387 = tensor.dim %15265, %c0 : tensor + %15294 = arith.index_cast %dim_5387 : index to i64 + %from_elements_5388 = tensor.from_elements %15294, %c1_i64 : tensor<2xi64> + %15295 = stablehlo.dynamic_reshape %15265, %from_elements_5388 : (tensor, tensor<2xi64>) -> tensor + %dim_5389 = tensor.dim %15262, %c0 : tensor + %15296 = arith.index_cast %dim_5389 : index to i64 + %from_elements_5390 = tensor.from_elements %15296, %c1_i64 : tensor<2xi64> + %15297 = stablehlo.dynamic_reshape %15262, %from_elements_5390 : (tensor, tensor<2xi64>) -> tensor + %15298 = stablehlo.concatenate %15295, %15297, dim = 1 : (tensor, tensor) -> tensor + %15299 = "stablehlo.gather"(%14915, %15298) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %15300 = shape.shape_of %15293 : tensor -> tensor<2xindex> + %15301 = shape.shape_of %15299 : tensor -> 
tensor<2xindex> + %15302 = shape.cstr_broadcastable %15300, %15301 : tensor<2xindex>, tensor<2xindex> + %15303 = shape.assuming %15302 -> (tensor) { + %19688 = shape.broadcast %15300, %15301 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %15293, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %15299, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %15304 = shape.shape_of %15303 : tensor -> tensor<2xindex> + %15305 = stablehlo.dynamic_broadcast_in_dim %15303, %15304, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %15306 = stablehlo.dynamic_broadcast_in_dim %213, %15304, dims = [] : (tensor, tensor<2xindex>) -> tensor + %15307 = stablehlo.multiply %15305, %15306 : tensor + %dim_5391 = tensor.dim %15267, %c0 : tensor + %15308 = arith.index_cast %dim_5391 : index to i64 + %dim_5392 = tensor.dim %15303, %c0 : tensor + %15309 = arith.index_cast %dim_5392 : index to i64 + %15310 = arith.maxsi %15308, %15309 : i64 + %15311 = arith.index_cast %15310 : i64 to index + %from_elements_5393 = tensor.from_elements %15311, %c4096 : tensor<2xindex> + %15312 = stablehlo.dynamic_broadcast_in_dim %15267, %from_elements_5393, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_5394 = tensor.dim %15312, %c0 : tensor + %15313 = arith.index_cast %dim_5394 : index to i64 + %from_elements_5395 = tensor.from_elements %15313, %c4096_i64 : tensor<2xi64> + %15314 = stablehlo.real_dynamic_slice %15307, %c_22, %from_elements_5395, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_5396 = tensor.from_elements %15313, %c4096_i64, %c1_i64 : tensor<3xi64> + %15315 = stablehlo.dynamic_reshape %15312, %from_elements_5396 : (tensor, tensor<3xi64>) -> tensor + %15316 = stablehlo.dynamic_iota %from_elements_5396, dim = 1 : (tensor<3xi64>) -> tensor + %15317 = stablehlo.concatenate %15315, %15316, dim = 2 : (tensor, tensor) -> tensor + %15318 = "stablehlo.scatter"(%15255, %15317, %15314) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %15319 = stablehlo.slice %14875 [7:8, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %15320 = stablehlo.reshape %15319 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %15321 = stablehlo.custom_call @byteir.non_zero(%15320) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_5397 = tensor.dim %15321, %c0 : tensor + %15322 = arith.index_cast %dim_5397 : index to i64 + %from_elements_5398 = tensor.from_elements %15322, %c1_i64 : tensor<2xi64> + %15323 = stablehlo.real_dynamic_slice %15321, %c_22, %from_elements_5398, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5399 = tensor.dim %15323, %c0 : tensor + %15324 = arith.index_cast %dim_5399 : index to i64 + %from_elements_5400 = tensor.from_elements %15324 : tensor<1xi64> + %15325 = stablehlo.dynamic_reshape %15323, %from_elements_5400 : (tensor, tensor<1xi64>) -> tensor + %from_elements_5401 = tensor.from_elements %15322, %c2_i64 : tensor<2xi64> + %15326 = stablehlo.real_dynamic_slice %15321, %c_24, %from_elements_5401, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + 
%dim_5402 = tensor.dim %15326, %c0 : tensor + %15327 = arith.index_cast %dim_5402 : index to i64 + %from_elements_5403 = tensor.from_elements %15327 : tensor<1xi64> + %15328 = stablehlo.dynamic_reshape %15326, %from_elements_5403 : (tensor, tensor<1xi64>) -> tensor + %dim_5404 = tensor.dim %15328, %c0 : tensor + %15329 = arith.index_cast %dim_5404 : index to i64 + %from_elements_5405 = tensor.from_elements %15329, %c1_i64 : tensor<2xi64> + %15330 = stablehlo.dynamic_reshape %15328, %from_elements_5405 : (tensor, tensor<2xi64>) -> tensor + %dim_5406 = tensor.dim %15330, %c0 : tensor + %15331 = arith.index_cast %dim_5406 : index to i64 + %from_elements_5407 = tensor.from_elements %c1_i64, %15331, %c4096_i64 : tensor<3xi64> + %15332 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_5407, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5408 = tensor.dim %15332, %c1 : tensor<1x?x4096xi64> + %15333 = arith.index_cast %dim_5408 : index to i64 + %from_elements_5409 = tensor.from_elements %c1_i64, %15333, %c4096_i64, %c1_i64 : tensor<4xi64> + %15334 = stablehlo.dynamic_reshape %15332, %from_elements_5409 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %15335 = stablehlo.dynamic_broadcast_in_dim %15330, %from_elements_5407, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5410 = tensor.dim %15335, %c1 : tensor<1x?x4096xi64> + %15336 = arith.index_cast %dim_5410 : index to i64 + %from_elements_5411 = tensor.from_elements %c1_i64, %15336, %c4096_i64, %c1_i64 : tensor<4xi64> + %15337 = stablehlo.dynamic_reshape %15335, %from_elements_5411 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %15338 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_5407, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5412 = tensor.dim %15338, %c1 : tensor<1x?x4096xi64> + %15339 = arith.index_cast %dim_5412 : index to i64 + %from_elements_5413 = tensor.from_elements %c1_i64, %15339, %c4096_i64, %c1_i64 : tensor<4xi64> + %15340 = stablehlo.dynamic_reshape %15338, %from_elements_5413 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %15341 = stablehlo.concatenate %15334, %15337, %15340, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %15342 = "stablehlo.gather"(%14886, %15341) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %15343 = shape.shape_of %15342 : tensor<1x?x4096xf32> -> tensor<3xindex> + %15344 = shape.num_elements %15343 : tensor<3xindex> -> index + %15345 = stablehlo.compute_reshape_shape %15344, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %15346 = stablehlo.dynamic_reshape %15342, %15345 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %15347 = stablehlo.dot %15346, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %15348 = stablehlo.logistic %15347 : tensor + %15349 = shape.shape_of %15348 : tensor -> tensor<2xindex> + %15350 = shape.shape_of %15347 : tensor -> tensor<2xindex> + %15351 = shape.cstr_broadcastable %15349, %15350 : tensor<2xindex>, tensor<2xindex> + %15352 = shape.assuming %15351 -> (tensor) { + %19688 = shape.broadcast %15349, %15350 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %15348, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %15347, 
%19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %15353 = shape.shape_of %15352 : tensor -> tensor<2xindex> + %15354 = shape.cstr_broadcastable %15353, %15350 : tensor<2xindex>, tensor<2xindex> + %15355 = shape.assuming %15354 -> (tensor) { + %19688 = shape.broadcast %15353, %15350 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %15352, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %15347, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %15356 = stablehlo.dot %15355, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_5414 = tensor.dim %15328, %c0 : tensor + %15357 = arith.index_cast %dim_5414 : index to i64 + %from_elements_5415 = tensor.from_elements %15357, %c1_i64 : tensor<2xi64> + %15358 = stablehlo.dynamic_reshape %15328, %from_elements_5415 : (tensor, tensor<2xi64>) -> tensor + %dim_5416 = tensor.dim %15325, %c0 : tensor + %15359 = arith.index_cast %dim_5416 : index to i64 + %from_elements_5417 = tensor.from_elements %15359, %c1_i64 : tensor<2xi64> + %15360 = stablehlo.dynamic_reshape %15325, %from_elements_5417 : (tensor, tensor<2xi64>) -> tensor + %15361 = stablehlo.concatenate %15358, %15360, dim = 1 : (tensor, tensor) -> tensor + %15362 = "stablehlo.gather"(%14915, %15361) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %15363 = shape.shape_of %15356 : tensor -> tensor<2xindex> + %15364 = shape.shape_of %15362 : tensor -> tensor<2xindex> + %15365 = shape.cstr_broadcastable %15363, %15364 : tensor<2xindex>, tensor<2xindex> + %15366 = shape.assuming %15365 -> (tensor) { + %19688 = shape.broadcast %15363, %15364 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %15356, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %15362, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %15367 = shape.shape_of %15366 : tensor -> tensor<2xindex> + %15368 = stablehlo.dynamic_broadcast_in_dim %15366, %15367, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %15369 = stablehlo.dynamic_broadcast_in_dim %213, %15367, dims = [] : (tensor, tensor<2xindex>) -> tensor + %15370 = stablehlo.multiply %15368, %15369 : tensor + %dim_5418 = tensor.dim %15330, %c0 : tensor + %15371 = arith.index_cast %dim_5418 : index to i64 + %dim_5419 = tensor.dim %15366, %c0 : tensor + %15372 = arith.index_cast %dim_5419 : index to i64 + %15373 = arith.maxsi %15371, %15372 : i64 + %15374 = arith.index_cast %15373 : i64 to index + %from_elements_5420 = tensor.from_elements %15374, %c4096 : tensor<2xindex> + %15375 = stablehlo.dynamic_broadcast_in_dim %15330, %from_elements_5420, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_5421 = tensor.dim %15375, %c0 : tensor + %15376 = arith.index_cast %dim_5421 : index to i64 + %from_elements_5422 = tensor.from_elements %15376, %c4096_i64 : tensor<2xi64> + %15377 = stablehlo.real_dynamic_slice %15370, %c_22, %from_elements_5422, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_5423 = tensor.from_elements 
%15376, %c4096_i64, %c1_i64 : tensor<3xi64> + %15378 = stablehlo.dynamic_reshape %15375, %from_elements_5423 : (tensor, tensor<3xi64>) -> tensor + %15379 = stablehlo.dynamic_iota %from_elements_5423, dim = 1 : (tensor<3xi64>) -> tensor + %15380 = stablehlo.concatenate %15378, %15379, dim = 2 : (tensor, tensor) -> tensor + %15381 = "stablehlo.scatter"(%15318, %15380, %15377) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %15382 = stablehlo.reshape %15381 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %15383 = stablehlo.add %14848, %15382 : tensor<3x1x4096xf32> + %15384 = stablehlo.broadcast_in_dim %15383, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %15385 = stablehlo.power %15384, %15 : tensor<3x1x4096xf32> + %15386 = stablehlo.reduce(%15385 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %15387 = stablehlo.reshape %15386 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %15388 = stablehlo.broadcast_in_dim %15387, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %15389 = stablehlo.divide %15388, %21 : tensor<3x1x1xf32> + %15390 = stablehlo.broadcast_in_dim %15389, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %15391 = stablehlo.add %15390, %25 : tensor<3x1x1xf32> + %15392 = stablehlo.rsqrt %15391 : tensor<3x1x1xf32> + %15393 = stablehlo.broadcast_in_dim %15392, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %15394 = stablehlo.multiply %15384, %15393 : tensor<3x1x4096xf32> + %15395 = stablehlo.broadcast_in_dim %15394, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %15396 = stablehlo.multiply %15395, %31 : tensor<3x1x4096xf32> + %15397 = stablehlo.reshape %15396 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %15398 = stablehlo.dot %15397, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %15399 = stablehlo.reshape %15398 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %15400 = stablehlo.dot %15397, %37 : (tensor<3x4096xf32>, tensor<4096x1024xf32>) -> tensor<3x1024xf32> + %15401 = stablehlo.reshape %15400 : (tensor<3x1024xf32>) -> tensor<3x1x1024xf32> + %15402 = stablehlo.reshape %15399 : (tensor<3x1x4096xf32>) -> tensor<3x1x32x128xf32> + %15403 = stablehlo.transpose %15402, dims = [0, 2, 1, 3] : (tensor<3x1x32x128xf32>) -> tensor<3x32x1x128xf32> + %15404 = stablehlo.reshape %15401 : (tensor<3x1x1024xf32>) -> tensor<3x1x8x128xf32> + %15405 = stablehlo.transpose %15404, dims = [0, 2, 1, 3] : (tensor<3x1x8x128xf32>) -> tensor<3x8x1x128xf32> + %15406 = stablehlo.slice %arg50 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %15407 = stablehlo.slice %arg51 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %15408 = "stablehlo.gather"(%15406, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %15409 = stablehlo.reshape %15408 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %15410 = "stablehlo.gather"(%15407, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %15411 = stablehlo.reshape %15410 : (tensor<1x1x128xf32>) -> 
tensor<1x1x1x128xf32> + %15412 = stablehlo.broadcast_in_dim %15403, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %15413 = stablehlo.broadcast_in_dim %15409, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %15414 = stablehlo.multiply %15412, %15413 : tensor<3x32x1x128xf32> + %15415 = stablehlo.slice %15403 [0:3, 0:32, 0:1, 0:64] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %15416 = stablehlo.slice %15403 [0:3, 0:32, 0:1, 64:128] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %15417 = stablehlo.negate %15416 : tensor<3x32x1x64xf32> + %15418 = stablehlo.concatenate %15417, %15415, dim = 3 : (tensor<3x32x1x64xf32>, tensor<3x32x1x64xf32>) -> tensor<3x32x1x128xf32> + %15419 = stablehlo.broadcast_in_dim %15418, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %15420 = stablehlo.broadcast_in_dim %15411, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %15421 = stablehlo.multiply %15419, %15420 : tensor<3x32x1x128xf32> + %15422 = stablehlo.add %15414, %15421 : tensor<3x32x1x128xf32> + %15423 = stablehlo.broadcast_in_dim %15405, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %15424 = stablehlo.broadcast_in_dim %15409, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %15425 = stablehlo.multiply %15423, %15424 : tensor<3x8x1x128xf32> + %15426 = stablehlo.slice %15405 [0:3, 0:8, 0:1, 0:64] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %15427 = stablehlo.slice %15405 [0:3, 0:8, 0:1, 64:128] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %15428 = stablehlo.negate %15427 : tensor<3x8x1x64xf32> + %15429 = stablehlo.concatenate %15428, %15426, dim = 3 : (tensor<3x8x1x64xf32>, tensor<3x8x1x64xf32>) -> tensor<3x8x1x128xf32> + %15430 = stablehlo.broadcast_in_dim %15429, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %15431 = stablehlo.broadcast_in_dim %15411, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %15432 = stablehlo.multiply %15430, %15431 : tensor<3x8x1x128xf32> + %15433 = stablehlo.add %15425, %15432 : tensor<3x8x1x128xf32> + %15434 = stablehlo.concatenate %arg115, %15433, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %15435 = stablehlo.concatenate %arg116, %15405, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %15436 = stablehlo.reshape %15434 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %15437 = stablehlo.broadcast_in_dim %15436, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %15438 = stablehlo.reshape %15437 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %15439 = stablehlo.reshape %15435 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %15440 = stablehlo.broadcast_in_dim %15439, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %15441 = stablehlo.reshape %15440 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %15442 = stablehlo.transpose %15438, dims = [0, 1, 3, 2] : (tensor<3x32x8x128xf32>) -> tensor<3x32x128x8xf32> + %15443 = stablehlo.reshape %15422 : (tensor<3x32x1x128xf32>) -> tensor<96x1x128xf32> + %15444 = stablehlo.reshape %15442 : (tensor<3x32x128x8xf32>) -> tensor<96x128x8xf32> + %15445 = stablehlo.broadcast_in_dim %15444, dims = [0, 1, 2] : (tensor<96x128x8xf32>) -> tensor<96x128x8xf32> + %15446 = stablehlo.dot_general %15443, %15445, batching_dims = [0] x [0], contracting_dims = [2] x 
[1] : (tensor<96x1x128xf32>, tensor<96x128x8xf32>) -> tensor<96x1x8xf32> + %15447 = stablehlo.reshape %15446 : (tensor<96x1x8xf32>) -> tensor<3x32x1x8xf32> + %15448 = stablehlo.broadcast_in_dim %15447, dims = [0, 1, 2, 3] : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %15449 = stablehlo.divide %15448, %89 : tensor<3x32x1x8xf32> + %15450 = stablehlo.custom_call @byteir.softmax(%15449) {byteir_attrs = {axis = 3 : i64}} : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %15451 = stablehlo.reshape %15450 : (tensor<3x32x1x8xf32>) -> tensor<96x1x8xf32> + %15452 = stablehlo.reshape %15441 : (tensor<3x32x8x128xf32>) -> tensor<96x8x128xf32> + %15453 = stablehlo.broadcast_in_dim %15452, dims = [0, 1, 2] : (tensor<96x8x128xf32>) -> tensor<96x8x128xf32> + %15454 = stablehlo.dot_general %15451, %15453, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x8xf32>, tensor<96x8x128xf32>) -> tensor<96x1x128xf32> + %15455 = stablehlo.reshape %15454 : (tensor<96x1x128xf32>) -> tensor<3x32x1x128xf32> + %15456 = stablehlo.transpose %15455, dims = [0, 2, 1, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x1x32x128xf32> + %15457 = stablehlo.reshape %15456 : (tensor<3x1x32x128xf32>) -> tensor<3x1x4096xf32> + %15458 = stablehlo.reshape %15457 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %15459 = stablehlo.dot %15458, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %15460 = stablehlo.reshape %15459 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %15461 = stablehlo.add %15383, %15460 : tensor<3x1x4096xf32> + %15462 = stablehlo.broadcast_in_dim %15461, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %15463 = stablehlo.power %15462, %15 : tensor<3x1x4096xf32> + %15464 = stablehlo.reduce(%15463 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %15465 = stablehlo.reshape %15464 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %15466 = stablehlo.broadcast_in_dim %15465, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %15467 = stablehlo.divide %15466, %21 : tensor<3x1x1xf32> + %15468 = stablehlo.broadcast_in_dim %15467, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %15469 = stablehlo.add %15468, %25 : tensor<3x1x1xf32> + %15470 = stablehlo.rsqrt %15469 : tensor<3x1x1xf32> + %15471 = stablehlo.broadcast_in_dim %15470, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %15472 = stablehlo.multiply %15462, %15471 : tensor<3x1x4096xf32> + %15473 = stablehlo.broadcast_in_dim %15472, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %15474 = stablehlo.multiply %15473, %31 : tensor<3x1x4096xf32> + %15475 = stablehlo.reshape %15474 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %15476 = stablehlo.dot %15475, %117 : (tensor<3x4096xf32>, tensor<4096x8xf32>) -> tensor<3x8xf32> + %15477 = stablehlo.custom_call @byteir.softmax(%15476) {byteir_attrs = {axis = 1 : i64}} : (tensor<3x8xf32>) -> tensor<3x8xf32> + %15478:2 = stablehlo.custom_call @byteir.top_k(%15477) {byteir_attrs = {axis = [1], k = 2 : i64, sorted = true}} : (tensor<3x8xf32>) -> (tensor<3x2xf32>, tensor<3x2xi64>) + %15479 = stablehlo.reduce(%15478#0 init: %cst_20) applies stablehlo.add across dimensions = [1] : (tensor<3x2xf32>, tensor) -> tensor<3xf32> + %15480 = stablehlo.reshape %15479 : (tensor<3xf32>) -> tensor<3x1xf32> + %15481 = stablehlo.broadcast_in_dim %15478#0, dims = [0, 1] : (tensor<3x2xf32>) -> tensor<3x2xf32> + %15482 = stablehlo.broadcast_in_dim %15480, dims = [0, 1] : 
(tensor<3x1xf32>) -> tensor<3x2xf32> + %15483 = stablehlo.divide %15481, %15482 : tensor<3x2xf32> + %15484 = stablehlo.reshape %15478#1 : (tensor<3x2xi64>) -> tensor<3x2x1xi64> + %15485 = stablehlo.broadcast_in_dim %15484, dims = [0, 1, 2] : (tensor<3x2x1xi64>) -> tensor<3x2x8xi64> + %15486 = stablehlo.compare EQ, %15485, %137, SIGNED : (tensor<3x2x8xi64>, tensor<3x2x8xi64>) -> tensor<3x2x8xi1> + %15487 = stablehlo.convert %15486 : (tensor<3x2x8xi1>) -> tensor<3x2x8xi64> + %15488 = stablehlo.transpose %15487, dims = [2, 1, 0] : (tensor<3x2x8xi64>) -> tensor<8x2x3xi64> + %15489 = stablehlo.slice %15488 [0:1, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %15490 = stablehlo.reshape %15489 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %15491 = stablehlo.custom_call @byteir.non_zero(%15490) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_5424 = tensor.dim %15491, %c0 : tensor + %15492 = arith.index_cast %dim_5424 : index to i64 + %from_elements_5425 = tensor.from_elements %15492, %c1_i64 : tensor<2xi64> + %15493 = stablehlo.real_dynamic_slice %15491, %c_22, %from_elements_5425, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5426 = tensor.dim %15493, %c0 : tensor + %15494 = arith.index_cast %dim_5426 : index to i64 + %from_elements_5427 = tensor.from_elements %15494 : tensor<1xi64> + %15495 = stablehlo.dynamic_reshape %15493, %from_elements_5427 : (tensor, tensor<1xi64>) -> tensor + %from_elements_5428 = tensor.from_elements %15492, %c2_i64 : tensor<2xi64> + %15496 = stablehlo.real_dynamic_slice %15491, %c_24, %from_elements_5428, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5429 = tensor.dim %15496, %c0 : tensor + %15497 = arith.index_cast %dim_5429 : index to i64 + %from_elements_5430 = tensor.from_elements %15497 : tensor<1xi64> + %15498 = stablehlo.dynamic_reshape %15496, %from_elements_5430 : (tensor, tensor<1xi64>) -> tensor + %15499 = stablehlo.reshape %15475 : (tensor<3x4096xf32>) -> tensor<1x3x4096xf32> + %dim_5431 = tensor.dim %15498, %c0 : tensor + %15500 = arith.index_cast %dim_5431 : index to i64 + %from_elements_5432 = tensor.from_elements %15500, %c1_i64 : tensor<2xi64> + %15501 = stablehlo.dynamic_reshape %15498, %from_elements_5432 : (tensor, tensor<2xi64>) -> tensor + %dim_5433 = tensor.dim %15501, %c0 : tensor + %15502 = arith.index_cast %dim_5433 : index to i64 + %from_elements_5434 = tensor.from_elements %c1_i64, %15502, %c4096_i64 : tensor<3xi64> + %15503 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_5434, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5435 = tensor.dim %15503, %c1 : tensor<1x?x4096xi64> + %15504 = arith.index_cast %dim_5435 : index to i64 + %from_elements_5436 = tensor.from_elements %c1_i64, %15504, %c4096_i64, %c1_i64 : tensor<4xi64> + %15505 = stablehlo.dynamic_reshape %15503, %from_elements_5436 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %15506 = stablehlo.dynamic_broadcast_in_dim %15501, %from_elements_5434, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5437 = tensor.dim %15506, %c1 : tensor<1x?x4096xi64> + %15507 = arith.index_cast %dim_5437 : index to i64 + %from_elements_5438 = tensor.from_elements %c1_i64, %15507, %c4096_i64, %c1_i64 : tensor<4xi64> + %15508 = stablehlo.dynamic_reshape %15506, %from_elements_5438 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %15509 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_5434, dims = [2] : 
(tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5439 = tensor.dim %15509, %c1 : tensor<1x?x4096xi64> + %15510 = arith.index_cast %dim_5439 : index to i64 + %from_elements_5440 = tensor.from_elements %c1_i64, %15510, %c4096_i64, %c1_i64 : tensor<4xi64> + %15511 = stablehlo.dynamic_reshape %15509, %from_elements_5440 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %15512 = stablehlo.concatenate %15505, %15508, %15511, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %15513 = "stablehlo.gather"(%15499, %15512) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %15514 = shape.shape_of %15513 : tensor<1x?x4096xf32> -> tensor<3xindex> + %15515 = shape.num_elements %15514 : tensor<3xindex> -> index + %15516 = stablehlo.compute_reshape_shape %15515, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %15517 = stablehlo.dynamic_reshape %15513, %15516 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %15518 = stablehlo.dot %15517, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %15519 = stablehlo.logistic %15518 : tensor + %15520 = shape.shape_of %15519 : tensor -> tensor<2xindex> + %15521 = shape.shape_of %15518 : tensor -> tensor<2xindex> + %15522 = shape.cstr_broadcastable %15520, %15521 : tensor<2xindex>, tensor<2xindex> + %15523 = shape.assuming %15522 -> (tensor) { + %19688 = shape.broadcast %15520, %15521 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %15519, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %15518, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %15524 = shape.shape_of %15523 : tensor -> tensor<2xindex> + %15525 = shape.cstr_broadcastable %15524, %15521 : tensor<2xindex>, tensor<2xindex> + %15526 = shape.assuming %15525 -> (tensor) { + %19688 = shape.broadcast %15524, %15521 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %15523, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %15518, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %15527 = stablehlo.dot %15526, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %15528 = stablehlo.reshape %15483 : (tensor<3x2xf32>) -> tensor<3x2x1xf32> + %dim_5441 = tensor.dim %15498, %c0 : tensor + %15529 = arith.index_cast %dim_5441 : index to i64 + %from_elements_5442 = tensor.from_elements %15529, %c1_i64 : tensor<2xi64> + %15530 = stablehlo.dynamic_reshape %15498, %from_elements_5442 : (tensor, tensor<2xi64>) -> tensor + %dim_5443 = tensor.dim %15495, %c0 : tensor + %15531 = arith.index_cast %dim_5443 : index to i64 + %from_elements_5444 = tensor.from_elements %15531, %c1_i64 : tensor<2xi64> + %15532 = stablehlo.dynamic_reshape %15495, %from_elements_5444 : (tensor, tensor<2xi64>) -> tensor + %15533 = stablehlo.concatenate %15530, %15532, dim = 1 : (tensor, tensor) -> tensor + %15534 = "stablehlo.gather"(%15528, %15533) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %15535 = shape.shape_of %15527 : tensor 
-> tensor<2xindex> + %15536 = shape.shape_of %15534 : tensor -> tensor<2xindex> + %15537 = shape.cstr_broadcastable %15535, %15536 : tensor<2xindex>, tensor<2xindex> + %15538 = shape.assuming %15537 -> (tensor) { + %19688 = shape.broadcast %15535, %15536 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %15527, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %15534, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %15539 = shape.shape_of %15538 : tensor -> tensor<2xindex> + %15540 = stablehlo.dynamic_broadcast_in_dim %15538, %15539, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %15541 = stablehlo.dynamic_broadcast_in_dim %213, %15539, dims = [] : (tensor, tensor<2xindex>) -> tensor + %15542 = stablehlo.multiply %15540, %15541 : tensor + %dim_5445 = tensor.dim %15501, %c0 : tensor + %15543 = arith.index_cast %dim_5445 : index to i64 + %dim_5446 = tensor.dim %15538, %c0 : tensor + %15544 = arith.index_cast %dim_5446 : index to i64 + %15545 = arith.maxsi %15543, %15544 : i64 + %15546 = arith.index_cast %15545 : i64 to index + %from_elements_5447 = tensor.from_elements %15546, %c4096 : tensor<2xindex> + %15547 = stablehlo.dynamic_broadcast_in_dim %15501, %from_elements_5447, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_5448 = tensor.dim %15547, %c0 : tensor + %15548 = arith.index_cast %dim_5448 : index to i64 + %from_elements_5449 = tensor.from_elements %15548, %c4096_i64 : tensor<2xi64> + %15549 = stablehlo.real_dynamic_slice %15542, %c_22, %from_elements_5449, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_5450 = tensor.from_elements %15548, %c4096_i64, %c1_i64 : tensor<3xi64> + %15550 = stablehlo.dynamic_reshape %15547, %from_elements_5450 : (tensor, tensor<3xi64>) -> tensor + %15551 = stablehlo.dynamic_iota %from_elements_5450, dim = 1 : (tensor<3xi64>) -> tensor + %15552 = stablehlo.concatenate %15550, %15551, dim = 2 : (tensor, tensor) -> tensor + %15553 = "stablehlo.scatter"(%cst_2, %15552, %15549) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %15554 = stablehlo.slice %15488 [1:2, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %15555 = stablehlo.reshape %15554 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %15556 = stablehlo.custom_call @byteir.non_zero(%15555) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_5451 = tensor.dim %15556, %c0 : tensor + %15557 = arith.index_cast %dim_5451 : index to i64 + %from_elements_5452 = tensor.from_elements %15557, %c1_i64 : tensor<2xi64> + %15558 = stablehlo.real_dynamic_slice %15556, %c_22, %from_elements_5452, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5453 = tensor.dim %15558, %c0 : tensor + %15559 = arith.index_cast %dim_5453 : index to i64 + %from_elements_5454 = tensor.from_elements %15559 : tensor<1xi64> + %15560 = stablehlo.dynamic_reshape %15558, %from_elements_5454 : (tensor, tensor<1xi64>) -> tensor + %from_elements_5455 = tensor.from_elements %15557, %c2_i64 : tensor<2xi64> + %15561 = stablehlo.real_dynamic_slice %15556, %c_24, %from_elements_5455, %c_21 : 
(tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5456 = tensor.dim %15561, %c0 : tensor + %15562 = arith.index_cast %dim_5456 : index to i64 + %from_elements_5457 = tensor.from_elements %15562 : tensor<1xi64> + %15563 = stablehlo.dynamic_reshape %15561, %from_elements_5457 : (tensor, tensor<1xi64>) -> tensor + %dim_5458 = tensor.dim %15563, %c0 : tensor + %15564 = arith.index_cast %dim_5458 : index to i64 + %from_elements_5459 = tensor.from_elements %15564, %c1_i64 : tensor<2xi64> + %15565 = stablehlo.dynamic_reshape %15563, %from_elements_5459 : (tensor, tensor<2xi64>) -> tensor + %dim_5460 = tensor.dim %15565, %c0 : tensor + %15566 = arith.index_cast %dim_5460 : index to i64 + %from_elements_5461 = tensor.from_elements %c1_i64, %15566, %c4096_i64 : tensor<3xi64> + %15567 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_5461, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5462 = tensor.dim %15567, %c1 : tensor<1x?x4096xi64> + %15568 = arith.index_cast %dim_5462 : index to i64 + %from_elements_5463 = tensor.from_elements %c1_i64, %15568, %c4096_i64, %c1_i64 : tensor<4xi64> + %15569 = stablehlo.dynamic_reshape %15567, %from_elements_5463 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %15570 = stablehlo.dynamic_broadcast_in_dim %15565, %from_elements_5461, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5464 = tensor.dim %15570, %c1 : tensor<1x?x4096xi64> + %15571 = arith.index_cast %dim_5464 : index to i64 + %from_elements_5465 = tensor.from_elements %c1_i64, %15571, %c4096_i64, %c1_i64 : tensor<4xi64> + %15572 = stablehlo.dynamic_reshape %15570, %from_elements_5465 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %15573 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_5461, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5466 = tensor.dim %15573, %c1 : tensor<1x?x4096xi64> + %15574 = arith.index_cast %dim_5466 : index to i64 + %from_elements_5467 = tensor.from_elements %c1_i64, %15574, %c4096_i64, %c1_i64 : tensor<4xi64> + %15575 = stablehlo.dynamic_reshape %15573, %from_elements_5467 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %15576 = stablehlo.concatenate %15569, %15572, %15575, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %15577 = "stablehlo.gather"(%15499, %15576) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %15578 = shape.shape_of %15577 : tensor<1x?x4096xf32> -> tensor<3xindex> + %15579 = shape.num_elements %15578 : tensor<3xindex> -> index + %15580 = stablehlo.compute_reshape_shape %15579, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %15581 = stablehlo.dynamic_reshape %15577, %15580 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %15582 = stablehlo.dot %15581, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %15583 = stablehlo.logistic %15582 : tensor + %15584 = shape.shape_of %15583 : tensor -> tensor<2xindex> + %15585 = shape.shape_of %15582 : tensor -> tensor<2xindex> + %15586 = shape.cstr_broadcastable %15584, %15585 : tensor<2xindex>, tensor<2xindex> + %15587 = shape.assuming %15586 -> (tensor) { + %19688 = shape.broadcast %15584, %15585 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %15583, %19688, dims = [0, 1] : (tensor, 
tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %15582, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %15588 = shape.shape_of %15587 : tensor -> tensor<2xindex> + %15589 = shape.cstr_broadcastable %15588, %15585 : tensor<2xindex>, tensor<2xindex> + %15590 = shape.assuming %15589 -> (tensor) { + %19688 = shape.broadcast %15588, %15585 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %15587, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %15582, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %15591 = stablehlo.dot %15590, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_5468 = tensor.dim %15563, %c0 : tensor + %15592 = arith.index_cast %dim_5468 : index to i64 + %from_elements_5469 = tensor.from_elements %15592, %c1_i64 : tensor<2xi64> + %15593 = stablehlo.dynamic_reshape %15563, %from_elements_5469 : (tensor, tensor<2xi64>) -> tensor + %dim_5470 = tensor.dim %15560, %c0 : tensor + %15594 = arith.index_cast %dim_5470 : index to i64 + %from_elements_5471 = tensor.from_elements %15594, %c1_i64 : tensor<2xi64> + %15595 = stablehlo.dynamic_reshape %15560, %from_elements_5471 : (tensor, tensor<2xi64>) -> tensor + %15596 = stablehlo.concatenate %15593, %15595, dim = 1 : (tensor, tensor) -> tensor + %15597 = "stablehlo.gather"(%15528, %15596) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %15598 = shape.shape_of %15591 : tensor -> tensor<2xindex> + %15599 = shape.shape_of %15597 : tensor -> tensor<2xindex> + %15600 = shape.cstr_broadcastable %15598, %15599 : tensor<2xindex>, tensor<2xindex> + %15601 = shape.assuming %15600 -> (tensor) { + %19688 = shape.broadcast %15598, %15599 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %15591, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %15597, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %15602 = shape.shape_of %15601 : tensor -> tensor<2xindex> + %15603 = stablehlo.dynamic_broadcast_in_dim %15601, %15602, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %15604 = stablehlo.dynamic_broadcast_in_dim %213, %15602, dims = [] : (tensor, tensor<2xindex>) -> tensor + %15605 = stablehlo.multiply %15603, %15604 : tensor + %dim_5472 = tensor.dim %15565, %c0 : tensor + %15606 = arith.index_cast %dim_5472 : index to i64 + %dim_5473 = tensor.dim %15601, %c0 : tensor + %15607 = arith.index_cast %dim_5473 : index to i64 + %15608 = arith.maxsi %15606, %15607 : i64 + %15609 = arith.index_cast %15608 : i64 to index + %from_elements_5474 = tensor.from_elements %15609, %c4096 : tensor<2xindex> + %15610 = stablehlo.dynamic_broadcast_in_dim %15565, %from_elements_5474, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_5475 = tensor.dim %15610, %c0 : tensor + %15611 = arith.index_cast %dim_5475 : index to i64 + %from_elements_5476 = tensor.from_elements %15611, %c4096_i64 : tensor<2xi64> + %15612 = stablehlo.real_dynamic_slice %15605, %c_22, %from_elements_5476, %c_21 : (tensor, tensor<2xi64>, 
tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_5477 = tensor.from_elements %15611, %c4096_i64, %c1_i64 : tensor<3xi64> + %15613 = stablehlo.dynamic_reshape %15610, %from_elements_5477 : (tensor, tensor<3xi64>) -> tensor + %15614 = stablehlo.dynamic_iota %from_elements_5477, dim = 1 : (tensor<3xi64>) -> tensor + %15615 = stablehlo.concatenate %15613, %15614, dim = 2 : (tensor, tensor) -> tensor + %15616 = "stablehlo.scatter"(%15553, %15615, %15612) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %15617 = stablehlo.slice %15488 [2:3, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %15618 = stablehlo.reshape %15617 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %15619 = stablehlo.custom_call @byteir.non_zero(%15618) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_5478 = tensor.dim %15619, %c0 : tensor + %15620 = arith.index_cast %dim_5478 : index to i64 + %from_elements_5479 = tensor.from_elements %15620, %c1_i64 : tensor<2xi64> + %15621 = stablehlo.real_dynamic_slice %15619, %c_22, %from_elements_5479, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5480 = tensor.dim %15621, %c0 : tensor + %15622 = arith.index_cast %dim_5480 : index to i64 + %from_elements_5481 = tensor.from_elements %15622 : tensor<1xi64> + %15623 = stablehlo.dynamic_reshape %15621, %from_elements_5481 : (tensor, tensor<1xi64>) -> tensor + %from_elements_5482 = tensor.from_elements %15620, %c2_i64 : tensor<2xi64> + %15624 = stablehlo.real_dynamic_slice %15619, %c_24, %from_elements_5482, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5483 = tensor.dim %15624, %c0 : tensor + %15625 = arith.index_cast %dim_5483 : index to i64 + %from_elements_5484 = tensor.from_elements %15625 : tensor<1xi64> + %15626 = stablehlo.dynamic_reshape %15624, %from_elements_5484 : (tensor, tensor<1xi64>) -> tensor + %dim_5485 = tensor.dim %15626, %c0 : tensor + %15627 = arith.index_cast %dim_5485 : index to i64 + %from_elements_5486 = tensor.from_elements %15627, %c1_i64 : tensor<2xi64> + %15628 = stablehlo.dynamic_reshape %15626, %from_elements_5486 : (tensor, tensor<2xi64>) -> tensor + %dim_5487 = tensor.dim %15628, %c0 : tensor + %15629 = arith.index_cast %dim_5487 : index to i64 + %from_elements_5488 = tensor.from_elements %c1_i64, %15629, %c4096_i64 : tensor<3xi64> + %15630 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_5488, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5489 = tensor.dim %15630, %c1 : tensor<1x?x4096xi64> + %15631 = arith.index_cast %dim_5489 : index to i64 + %from_elements_5490 = tensor.from_elements %c1_i64, %15631, %c4096_i64, %c1_i64 : tensor<4xi64> + %15632 = stablehlo.dynamic_reshape %15630, %from_elements_5490 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %15633 = stablehlo.dynamic_broadcast_in_dim %15628, %from_elements_5488, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5491 = tensor.dim %15633, %c1 : tensor<1x?x4096xi64> + %15634 = arith.index_cast %dim_5491 : index to i64 + %from_elements_5492 = tensor.from_elements %c1_i64, %15634, %c4096_i64, %c1_i64 : tensor<4xi64> + %15635 = stablehlo.dynamic_reshape %15633, %from_elements_5492 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> 
tensor<1x?x4096x1xi64> + %15636 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_5488, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5493 = tensor.dim %15636, %c1 : tensor<1x?x4096xi64> + %15637 = arith.index_cast %dim_5493 : index to i64 + %from_elements_5494 = tensor.from_elements %c1_i64, %15637, %c4096_i64, %c1_i64 : tensor<4xi64> + %15638 = stablehlo.dynamic_reshape %15636, %from_elements_5494 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %15639 = stablehlo.concatenate %15632, %15635, %15638, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %15640 = "stablehlo.gather"(%15499, %15639) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %15641 = shape.shape_of %15640 : tensor<1x?x4096xf32> -> tensor<3xindex> + %15642 = shape.num_elements %15641 : tensor<3xindex> -> index + %15643 = stablehlo.compute_reshape_shape %15642, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %15644 = stablehlo.dynamic_reshape %15640, %15643 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %15645 = stablehlo.dot %15644, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %15646 = stablehlo.logistic %15645 : tensor + %15647 = shape.shape_of %15646 : tensor -> tensor<2xindex> + %15648 = shape.shape_of %15645 : tensor -> tensor<2xindex> + %15649 = shape.cstr_broadcastable %15647, %15648 : tensor<2xindex>, tensor<2xindex> + %15650 = shape.assuming %15649 -> (tensor) { + %19688 = shape.broadcast %15647, %15648 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %15646, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %15645, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %15651 = shape.shape_of %15650 : tensor -> tensor<2xindex> + %15652 = shape.cstr_broadcastable %15651, %15648 : tensor<2xindex>, tensor<2xindex> + %15653 = shape.assuming %15652 -> (tensor) { + %19688 = shape.broadcast %15651, %15648 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %15650, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %15645, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %15654 = stablehlo.dot %15653, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_5495 = tensor.dim %15626, %c0 : tensor + %15655 = arith.index_cast %dim_5495 : index to i64 + %from_elements_5496 = tensor.from_elements %15655, %c1_i64 : tensor<2xi64> + %15656 = stablehlo.dynamic_reshape %15626, %from_elements_5496 : (tensor, tensor<2xi64>) -> tensor + %dim_5497 = tensor.dim %15623, %c0 : tensor + %15657 = arith.index_cast %dim_5497 : index to i64 + %from_elements_5498 = tensor.from_elements %15657, %c1_i64 : tensor<2xi64> + %15658 = stablehlo.dynamic_reshape %15623, %from_elements_5498 : (tensor, tensor<2xi64>) -> tensor + %15659 = stablehlo.concatenate %15656, %15658, dim = 1 : (tensor, tensor) -> tensor + %15660 = "stablehlo.gather"(%15528, %15659) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %15661 
= shape.shape_of %15654 : tensor -> tensor<2xindex> + %15662 = shape.shape_of %15660 : tensor -> tensor<2xindex> + %15663 = shape.cstr_broadcastable %15661, %15662 : tensor<2xindex>, tensor<2xindex> + %15664 = shape.assuming %15663 -> (tensor) { + %19688 = shape.broadcast %15661, %15662 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %15654, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %15660, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %15665 = shape.shape_of %15664 : tensor -> tensor<2xindex> + %15666 = stablehlo.dynamic_broadcast_in_dim %15664, %15665, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %15667 = stablehlo.dynamic_broadcast_in_dim %213, %15665, dims = [] : (tensor, tensor<2xindex>) -> tensor + %15668 = stablehlo.multiply %15666, %15667 : tensor + %dim_5499 = tensor.dim %15628, %c0 : tensor + %15669 = arith.index_cast %dim_5499 : index to i64 + %dim_5500 = tensor.dim %15664, %c0 : tensor + %15670 = arith.index_cast %dim_5500 : index to i64 + %15671 = arith.maxsi %15669, %15670 : i64 + %15672 = arith.index_cast %15671 : i64 to index + %from_elements_5501 = tensor.from_elements %15672, %c4096 : tensor<2xindex> + %15673 = stablehlo.dynamic_broadcast_in_dim %15628, %from_elements_5501, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_5502 = tensor.dim %15673, %c0 : tensor + %15674 = arith.index_cast %dim_5502 : index to i64 + %from_elements_5503 = tensor.from_elements %15674, %c4096_i64 : tensor<2xi64> + %15675 = stablehlo.real_dynamic_slice %15668, %c_22, %from_elements_5503, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_5504 = tensor.from_elements %15674, %c4096_i64, %c1_i64 : tensor<3xi64> + %15676 = stablehlo.dynamic_reshape %15673, %from_elements_5504 : (tensor, tensor<3xi64>) -> tensor + %15677 = stablehlo.dynamic_iota %from_elements_5504, dim = 1 : (tensor<3xi64>) -> tensor + %15678 = stablehlo.concatenate %15676, %15677, dim = 2 : (tensor, tensor) -> tensor + %15679 = "stablehlo.scatter"(%15616, %15678, %15675) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %15680 = stablehlo.slice %15488 [3:4, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %15681 = stablehlo.reshape %15680 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %15682 = stablehlo.custom_call @byteir.non_zero(%15681) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_5505 = tensor.dim %15682, %c0 : tensor + %15683 = arith.index_cast %dim_5505 : index to i64 + %from_elements_5506 = tensor.from_elements %15683, %c1_i64 : tensor<2xi64> + %15684 = stablehlo.real_dynamic_slice %15682, %c_22, %from_elements_5506, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5507 = tensor.dim %15684, %c0 : tensor + %15685 = arith.index_cast %dim_5507 : index to i64 + %from_elements_5508 = tensor.from_elements %15685 : tensor<1xi64> + %15686 = stablehlo.dynamic_reshape %15684, %from_elements_5508 : (tensor, tensor<1xi64>) -> tensor + %from_elements_5509 = tensor.from_elements %15683, %c2_i64 : tensor<2xi64> + %15687 = stablehlo.real_dynamic_slice %15682, %c_24, 
%from_elements_5509, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5510 = tensor.dim %15687, %c0 : tensor + %15688 = arith.index_cast %dim_5510 : index to i64 + %from_elements_5511 = tensor.from_elements %15688 : tensor<1xi64> + %15689 = stablehlo.dynamic_reshape %15687, %from_elements_5511 : (tensor, tensor<1xi64>) -> tensor + %dim_5512 = tensor.dim %15689, %c0 : tensor + %15690 = arith.index_cast %dim_5512 : index to i64 + %from_elements_5513 = tensor.from_elements %15690, %c1_i64 : tensor<2xi64> + %15691 = stablehlo.dynamic_reshape %15689, %from_elements_5513 : (tensor, tensor<2xi64>) -> tensor + %dim_5514 = tensor.dim %15691, %c0 : tensor + %15692 = arith.index_cast %dim_5514 : index to i64 + %from_elements_5515 = tensor.from_elements %c1_i64, %15692, %c4096_i64 : tensor<3xi64> + %15693 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_5515, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5516 = tensor.dim %15693, %c1 : tensor<1x?x4096xi64> + %15694 = arith.index_cast %dim_5516 : index to i64 + %from_elements_5517 = tensor.from_elements %c1_i64, %15694, %c4096_i64, %c1_i64 : tensor<4xi64> + %15695 = stablehlo.dynamic_reshape %15693, %from_elements_5517 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %15696 = stablehlo.dynamic_broadcast_in_dim %15691, %from_elements_5515, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5518 = tensor.dim %15696, %c1 : tensor<1x?x4096xi64> + %15697 = arith.index_cast %dim_5518 : index to i64 + %from_elements_5519 = tensor.from_elements %c1_i64, %15697, %c4096_i64, %c1_i64 : tensor<4xi64> + %15698 = stablehlo.dynamic_reshape %15696, %from_elements_5519 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %15699 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_5515, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5520 = tensor.dim %15699, %c1 : tensor<1x?x4096xi64> + %15700 = arith.index_cast %dim_5520 : index to i64 + %from_elements_5521 = tensor.from_elements %c1_i64, %15700, %c4096_i64, %c1_i64 : tensor<4xi64> + %15701 = stablehlo.dynamic_reshape %15699, %from_elements_5521 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %15702 = stablehlo.concatenate %15695, %15698, %15701, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %15703 = "stablehlo.gather"(%15499, %15702) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %15704 = shape.shape_of %15703 : tensor<1x?x4096xf32> -> tensor<3xindex> + %15705 = shape.num_elements %15704 : tensor<3xindex> -> index + %15706 = stablehlo.compute_reshape_shape %15705, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %15707 = stablehlo.dynamic_reshape %15703, %15706 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %15708 = stablehlo.dot %15707, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %15709 = stablehlo.logistic %15708 : tensor + %15710 = shape.shape_of %15709 : tensor -> tensor<2xindex> + %15711 = shape.shape_of %15708 : tensor -> tensor<2xindex> + %15712 = shape.cstr_broadcastable %15710, %15711 : tensor<2xindex>, tensor<2xindex> + %15713 = shape.assuming %15712 -> (tensor) { + %19688 = shape.broadcast %15710, %15711 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %15709, %19688, dims = [0, 
1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %15708, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %15714 = shape.shape_of %15713 : tensor -> tensor<2xindex> + %15715 = shape.cstr_broadcastable %15714, %15711 : tensor<2xindex>, tensor<2xindex> + %15716 = shape.assuming %15715 -> (tensor) { + %19688 = shape.broadcast %15714, %15711 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %15713, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %15708, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %15717 = stablehlo.dot %15716, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_5522 = tensor.dim %15689, %c0 : tensor + %15718 = arith.index_cast %dim_5522 : index to i64 + %from_elements_5523 = tensor.from_elements %15718, %c1_i64 : tensor<2xi64> + %15719 = stablehlo.dynamic_reshape %15689, %from_elements_5523 : (tensor, tensor<2xi64>) -> tensor + %dim_5524 = tensor.dim %15686, %c0 : tensor + %15720 = arith.index_cast %dim_5524 : index to i64 + %from_elements_5525 = tensor.from_elements %15720, %c1_i64 : tensor<2xi64> + %15721 = stablehlo.dynamic_reshape %15686, %from_elements_5525 : (tensor, tensor<2xi64>) -> tensor + %15722 = stablehlo.concatenate %15719, %15721, dim = 1 : (tensor, tensor) -> tensor + %15723 = "stablehlo.gather"(%15528, %15722) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %15724 = shape.shape_of %15717 : tensor -> tensor<2xindex> + %15725 = shape.shape_of %15723 : tensor -> tensor<2xindex> + %15726 = shape.cstr_broadcastable %15724, %15725 : tensor<2xindex>, tensor<2xindex> + %15727 = shape.assuming %15726 -> (tensor) { + %19688 = shape.broadcast %15724, %15725 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %15717, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %15723, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %15728 = shape.shape_of %15727 : tensor -> tensor<2xindex> + %15729 = stablehlo.dynamic_broadcast_in_dim %15727, %15728, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %15730 = stablehlo.dynamic_broadcast_in_dim %213, %15728, dims = [] : (tensor, tensor<2xindex>) -> tensor + %15731 = stablehlo.multiply %15729, %15730 : tensor + %dim_5526 = tensor.dim %15691, %c0 : tensor + %15732 = arith.index_cast %dim_5526 : index to i64 + %dim_5527 = tensor.dim %15727, %c0 : tensor + %15733 = arith.index_cast %dim_5527 : index to i64 + %15734 = arith.maxsi %15732, %15733 : i64 + %15735 = arith.index_cast %15734 : i64 to index + %from_elements_5528 = tensor.from_elements %15735, %c4096 : tensor<2xindex> + %15736 = stablehlo.dynamic_broadcast_in_dim %15691, %from_elements_5528, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_5529 = tensor.dim %15736, %c0 : tensor + %15737 = arith.index_cast %dim_5529 : index to i64 + %from_elements_5530 = tensor.from_elements %15737, %c4096_i64 : tensor<2xi64> + %15738 = stablehlo.real_dynamic_slice %15731, %c_22, %from_elements_5530, %c_21 : (tensor, 
tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_5531 = tensor.from_elements %15737, %c4096_i64, %c1_i64 : tensor<3xi64> + %15739 = stablehlo.dynamic_reshape %15736, %from_elements_5531 : (tensor, tensor<3xi64>) -> tensor + %15740 = stablehlo.dynamic_iota %from_elements_5531, dim = 1 : (tensor<3xi64>) -> tensor + %15741 = stablehlo.concatenate %15739, %15740, dim = 2 : (tensor, tensor) -> tensor + %15742 = "stablehlo.scatter"(%15679, %15741, %15738) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %15743 = stablehlo.slice %15488 [4:5, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %15744 = stablehlo.reshape %15743 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %15745 = stablehlo.custom_call @byteir.non_zero(%15744) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_5532 = tensor.dim %15745, %c0 : tensor + %15746 = arith.index_cast %dim_5532 : index to i64 + %from_elements_5533 = tensor.from_elements %15746, %c1_i64 : tensor<2xi64> + %15747 = stablehlo.real_dynamic_slice %15745, %c_22, %from_elements_5533, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5534 = tensor.dim %15747, %c0 : tensor + %15748 = arith.index_cast %dim_5534 : index to i64 + %from_elements_5535 = tensor.from_elements %15748 : tensor<1xi64> + %15749 = stablehlo.dynamic_reshape %15747, %from_elements_5535 : (tensor, tensor<1xi64>) -> tensor + %from_elements_5536 = tensor.from_elements %15746, %c2_i64 : tensor<2xi64> + %15750 = stablehlo.real_dynamic_slice %15745, %c_24, %from_elements_5536, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5537 = tensor.dim %15750, %c0 : tensor + %15751 = arith.index_cast %dim_5537 : index to i64 + %from_elements_5538 = tensor.from_elements %15751 : tensor<1xi64> + %15752 = stablehlo.dynamic_reshape %15750, %from_elements_5538 : (tensor, tensor<1xi64>) -> tensor + %dim_5539 = tensor.dim %15752, %c0 : tensor + %15753 = arith.index_cast %dim_5539 : index to i64 + %from_elements_5540 = tensor.from_elements %15753, %c1_i64 : tensor<2xi64> + %15754 = stablehlo.dynamic_reshape %15752, %from_elements_5540 : (tensor, tensor<2xi64>) -> tensor + %dim_5541 = tensor.dim %15754, %c0 : tensor + %15755 = arith.index_cast %dim_5541 : index to i64 + %from_elements_5542 = tensor.from_elements %c1_i64, %15755, %c4096_i64 : tensor<3xi64> + %15756 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_5542, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5543 = tensor.dim %15756, %c1 : tensor<1x?x4096xi64> + %15757 = arith.index_cast %dim_5543 : index to i64 + %from_elements_5544 = tensor.from_elements %c1_i64, %15757, %c4096_i64, %c1_i64 : tensor<4xi64> + %15758 = stablehlo.dynamic_reshape %15756, %from_elements_5544 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %15759 = stablehlo.dynamic_broadcast_in_dim %15754, %from_elements_5542, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5545 = tensor.dim %15759, %c1 : tensor<1x?x4096xi64> + %15760 = arith.index_cast %dim_5545 : index to i64 + %from_elements_5546 = tensor.from_elements %c1_i64, %15760, %c4096_i64, %c1_i64 : tensor<4xi64> + %15761 = stablehlo.dynamic_reshape %15759, %from_elements_5546 : (tensor<1x?x4096xi64>, tensor<4xi64>) 
-> tensor<1x?x4096x1xi64> + %15762 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_5542, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5547 = tensor.dim %15762, %c1 : tensor<1x?x4096xi64> + %15763 = arith.index_cast %dim_5547 : index to i64 + %from_elements_5548 = tensor.from_elements %c1_i64, %15763, %c4096_i64, %c1_i64 : tensor<4xi64> + %15764 = stablehlo.dynamic_reshape %15762, %from_elements_5548 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %15765 = stablehlo.concatenate %15758, %15761, %15764, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %15766 = "stablehlo.gather"(%15499, %15765) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %15767 = shape.shape_of %15766 : tensor<1x?x4096xf32> -> tensor<3xindex> + %15768 = shape.num_elements %15767 : tensor<3xindex> -> index + %15769 = stablehlo.compute_reshape_shape %15768, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %15770 = stablehlo.dynamic_reshape %15766, %15769 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %15771 = stablehlo.dot %15770, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %15772 = stablehlo.logistic %15771 : tensor + %15773 = shape.shape_of %15772 : tensor -> tensor<2xindex> + %15774 = shape.shape_of %15771 : tensor -> tensor<2xindex> + %15775 = shape.cstr_broadcastable %15773, %15774 : tensor<2xindex>, tensor<2xindex> + %15776 = shape.assuming %15775 -> (tensor) { + %19688 = shape.broadcast %15773, %15774 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %15772, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %15771, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %15777 = shape.shape_of %15776 : tensor -> tensor<2xindex> + %15778 = shape.cstr_broadcastable %15777, %15774 : tensor<2xindex>, tensor<2xindex> + %15779 = shape.assuming %15778 -> (tensor) { + %19688 = shape.broadcast %15777, %15774 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %15776, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %15771, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %15780 = stablehlo.dot %15779, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_5549 = tensor.dim %15752, %c0 : tensor + %15781 = arith.index_cast %dim_5549 : index to i64 + %from_elements_5550 = tensor.from_elements %15781, %c1_i64 : tensor<2xi64> + %15782 = stablehlo.dynamic_reshape %15752, %from_elements_5550 : (tensor, tensor<2xi64>) -> tensor + %dim_5551 = tensor.dim %15749, %c0 : tensor + %15783 = arith.index_cast %dim_5551 : index to i64 + %from_elements_5552 = tensor.from_elements %15783, %c1_i64 : tensor<2xi64> + %15784 = stablehlo.dynamic_reshape %15749, %from_elements_5552 : (tensor, tensor<2xi64>) -> tensor + %15785 = stablehlo.concatenate %15782, %15784, dim = 1 : (tensor, tensor) -> tensor + %15786 = "stablehlo.gather"(%15528, %15785) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + 
%15787 = shape.shape_of %15780 : tensor -> tensor<2xindex> + %15788 = shape.shape_of %15786 : tensor -> tensor<2xindex> + %15789 = shape.cstr_broadcastable %15787, %15788 : tensor<2xindex>, tensor<2xindex> + %15790 = shape.assuming %15789 -> (tensor) { + %19688 = shape.broadcast %15787, %15788 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %15780, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %15786, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %15791 = shape.shape_of %15790 : tensor -> tensor<2xindex> + %15792 = stablehlo.dynamic_broadcast_in_dim %15790, %15791, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %15793 = stablehlo.dynamic_broadcast_in_dim %213, %15791, dims = [] : (tensor, tensor<2xindex>) -> tensor + %15794 = stablehlo.multiply %15792, %15793 : tensor + %dim_5553 = tensor.dim %15754, %c0 : tensor + %15795 = arith.index_cast %dim_5553 : index to i64 + %dim_5554 = tensor.dim %15790, %c0 : tensor + %15796 = arith.index_cast %dim_5554 : index to i64 + %15797 = arith.maxsi %15795, %15796 : i64 + %15798 = arith.index_cast %15797 : i64 to index + %from_elements_5555 = tensor.from_elements %15798, %c4096 : tensor<2xindex> + %15799 = stablehlo.dynamic_broadcast_in_dim %15754, %from_elements_5555, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_5556 = tensor.dim %15799, %c0 : tensor + %15800 = arith.index_cast %dim_5556 : index to i64 + %from_elements_5557 = tensor.from_elements %15800, %c4096_i64 : tensor<2xi64> + %15801 = stablehlo.real_dynamic_slice %15794, %c_22, %from_elements_5557, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_5558 = tensor.from_elements %15800, %c4096_i64, %c1_i64 : tensor<3xi64> + %15802 = stablehlo.dynamic_reshape %15799, %from_elements_5558 : (tensor, tensor<3xi64>) -> tensor + %15803 = stablehlo.dynamic_iota %from_elements_5558, dim = 1 : (tensor<3xi64>) -> tensor + %15804 = stablehlo.concatenate %15802, %15803, dim = 2 : (tensor, tensor) -> tensor + %15805 = "stablehlo.scatter"(%15742, %15804, %15801) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %15806 = stablehlo.slice %15488 [5:6, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %15807 = stablehlo.reshape %15806 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %15808 = stablehlo.custom_call @byteir.non_zero(%15807) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_5559 = tensor.dim %15808, %c0 : tensor + %15809 = arith.index_cast %dim_5559 : index to i64 + %from_elements_5560 = tensor.from_elements %15809, %c1_i64 : tensor<2xi64> + %15810 = stablehlo.real_dynamic_slice %15808, %c_22, %from_elements_5560, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5561 = tensor.dim %15810, %c0 : tensor + %15811 = arith.index_cast %dim_5561 : index to i64 + %from_elements_5562 = tensor.from_elements %15811 : tensor<1xi64> + %15812 = stablehlo.dynamic_reshape %15810, %from_elements_5562 : (tensor, tensor<1xi64>) -> tensor + %from_elements_5563 = tensor.from_elements %15809, %c2_i64 : tensor<2xi64> + %15813 = stablehlo.real_dynamic_slice %15808, 
%c_24, %from_elements_5563, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5564 = tensor.dim %15813, %c0 : tensor + %15814 = arith.index_cast %dim_5564 : index to i64 + %from_elements_5565 = tensor.from_elements %15814 : tensor<1xi64> + %15815 = stablehlo.dynamic_reshape %15813, %from_elements_5565 : (tensor, tensor<1xi64>) -> tensor + %dim_5566 = tensor.dim %15815, %c0 : tensor + %15816 = arith.index_cast %dim_5566 : index to i64 + %from_elements_5567 = tensor.from_elements %15816, %c1_i64 : tensor<2xi64> + %15817 = stablehlo.dynamic_reshape %15815, %from_elements_5567 : (tensor, tensor<2xi64>) -> tensor + %dim_5568 = tensor.dim %15817, %c0 : tensor + %15818 = arith.index_cast %dim_5568 : index to i64 + %from_elements_5569 = tensor.from_elements %c1_i64, %15818, %c4096_i64 : tensor<3xi64> + %15819 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_5569, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5570 = tensor.dim %15819, %c1 : tensor<1x?x4096xi64> + %15820 = arith.index_cast %dim_5570 : index to i64 + %from_elements_5571 = tensor.from_elements %c1_i64, %15820, %c4096_i64, %c1_i64 : tensor<4xi64> + %15821 = stablehlo.dynamic_reshape %15819, %from_elements_5571 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %15822 = stablehlo.dynamic_broadcast_in_dim %15817, %from_elements_5569, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5572 = tensor.dim %15822, %c1 : tensor<1x?x4096xi64> + %15823 = arith.index_cast %dim_5572 : index to i64 + %from_elements_5573 = tensor.from_elements %c1_i64, %15823, %c4096_i64, %c1_i64 : tensor<4xi64> + %15824 = stablehlo.dynamic_reshape %15822, %from_elements_5573 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %15825 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_5569, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5574 = tensor.dim %15825, %c1 : tensor<1x?x4096xi64> + %15826 = arith.index_cast %dim_5574 : index to i64 + %from_elements_5575 = tensor.from_elements %c1_i64, %15826, %c4096_i64, %c1_i64 : tensor<4xi64> + %15827 = stablehlo.dynamic_reshape %15825, %from_elements_5575 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %15828 = stablehlo.concatenate %15821, %15824, %15827, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %15829 = "stablehlo.gather"(%15499, %15828) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %15830 = shape.shape_of %15829 : tensor<1x?x4096xf32> -> tensor<3xindex> + %15831 = shape.num_elements %15830 : tensor<3xindex> -> index + %15832 = stablehlo.compute_reshape_shape %15831, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %15833 = stablehlo.dynamic_reshape %15829, %15832 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %15834 = stablehlo.dot %15833, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %15835 = stablehlo.logistic %15834 : tensor + %15836 = shape.shape_of %15835 : tensor -> tensor<2xindex> + %15837 = shape.shape_of %15834 : tensor -> tensor<2xindex> + %15838 = shape.cstr_broadcastable %15836, %15837 : tensor<2xindex>, tensor<2xindex> + %15839 = shape.assuming %15838 -> (tensor) { + %19688 = shape.broadcast %15836, %15837 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %15835, %19688, 
dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %15834, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %15840 = shape.shape_of %15839 : tensor -> tensor<2xindex> + %15841 = shape.cstr_broadcastable %15840, %15837 : tensor<2xindex>, tensor<2xindex> + %15842 = shape.assuming %15841 -> (tensor) { + %19688 = shape.broadcast %15840, %15837 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %15839, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %15834, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %15843 = stablehlo.dot %15842, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_5576 = tensor.dim %15815, %c0 : tensor + %15844 = arith.index_cast %dim_5576 : index to i64 + %from_elements_5577 = tensor.from_elements %15844, %c1_i64 : tensor<2xi64> + %15845 = stablehlo.dynamic_reshape %15815, %from_elements_5577 : (tensor, tensor<2xi64>) -> tensor + %dim_5578 = tensor.dim %15812, %c0 : tensor + %15846 = arith.index_cast %dim_5578 : index to i64 + %from_elements_5579 = tensor.from_elements %15846, %c1_i64 : tensor<2xi64> + %15847 = stablehlo.dynamic_reshape %15812, %from_elements_5579 : (tensor, tensor<2xi64>) -> tensor + %15848 = stablehlo.concatenate %15845, %15847, dim = 1 : (tensor, tensor) -> tensor + %15849 = "stablehlo.gather"(%15528, %15848) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %15850 = shape.shape_of %15843 : tensor -> tensor<2xindex> + %15851 = shape.shape_of %15849 : tensor -> tensor<2xindex> + %15852 = shape.cstr_broadcastable %15850, %15851 : tensor<2xindex>, tensor<2xindex> + %15853 = shape.assuming %15852 -> (tensor) { + %19688 = shape.broadcast %15850, %15851 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %15843, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %15849, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %15854 = shape.shape_of %15853 : tensor -> tensor<2xindex> + %15855 = stablehlo.dynamic_broadcast_in_dim %15853, %15854, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %15856 = stablehlo.dynamic_broadcast_in_dim %213, %15854, dims = [] : (tensor, tensor<2xindex>) -> tensor + %15857 = stablehlo.multiply %15855, %15856 : tensor + %dim_5580 = tensor.dim %15817, %c0 : tensor + %15858 = arith.index_cast %dim_5580 : index to i64 + %dim_5581 = tensor.dim %15853, %c0 : tensor + %15859 = arith.index_cast %dim_5581 : index to i64 + %15860 = arith.maxsi %15858, %15859 : i64 + %15861 = arith.index_cast %15860 : i64 to index + %from_elements_5582 = tensor.from_elements %15861, %c4096 : tensor<2xindex> + %15862 = stablehlo.dynamic_broadcast_in_dim %15817, %from_elements_5582, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_5583 = tensor.dim %15862, %c0 : tensor + %15863 = arith.index_cast %dim_5583 : index to i64 + %from_elements_5584 = tensor.from_elements %15863, %c4096_i64 : tensor<2xi64> + %15864 = stablehlo.real_dynamic_slice %15857, %c_22, %from_elements_5584, %c_21 : 
(tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_5585 = tensor.from_elements %15863, %c4096_i64, %c1_i64 : tensor<3xi64> + %15865 = stablehlo.dynamic_reshape %15862, %from_elements_5585 : (tensor, tensor<3xi64>) -> tensor + %15866 = stablehlo.dynamic_iota %from_elements_5585, dim = 1 : (tensor<3xi64>) -> tensor + %15867 = stablehlo.concatenate %15865, %15866, dim = 2 : (tensor, tensor) -> tensor + %15868 = "stablehlo.scatter"(%15805, %15867, %15864) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %15869 = stablehlo.slice %15488 [6:7, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %15870 = stablehlo.reshape %15869 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %15871 = stablehlo.custom_call @byteir.non_zero(%15870) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_5586 = tensor.dim %15871, %c0 : tensor + %15872 = arith.index_cast %dim_5586 : index to i64 + %from_elements_5587 = tensor.from_elements %15872, %c1_i64 : tensor<2xi64> + %15873 = stablehlo.real_dynamic_slice %15871, %c_22, %from_elements_5587, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5588 = tensor.dim %15873, %c0 : tensor + %15874 = arith.index_cast %dim_5588 : index to i64 + %from_elements_5589 = tensor.from_elements %15874 : tensor<1xi64> + %15875 = stablehlo.dynamic_reshape %15873, %from_elements_5589 : (tensor, tensor<1xi64>) -> tensor + %from_elements_5590 = tensor.from_elements %15872, %c2_i64 : tensor<2xi64> + %15876 = stablehlo.real_dynamic_slice %15871, %c_24, %from_elements_5590, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5591 = tensor.dim %15876, %c0 : tensor + %15877 = arith.index_cast %dim_5591 : index to i64 + %from_elements_5592 = tensor.from_elements %15877 : tensor<1xi64> + %15878 = stablehlo.dynamic_reshape %15876, %from_elements_5592 : (tensor, tensor<1xi64>) -> tensor + %dim_5593 = tensor.dim %15878, %c0 : tensor + %15879 = arith.index_cast %dim_5593 : index to i64 + %from_elements_5594 = tensor.from_elements %15879, %c1_i64 : tensor<2xi64> + %15880 = stablehlo.dynamic_reshape %15878, %from_elements_5594 : (tensor, tensor<2xi64>) -> tensor + %dim_5595 = tensor.dim %15880, %c0 : tensor + %15881 = arith.index_cast %dim_5595 : index to i64 + %from_elements_5596 = tensor.from_elements %c1_i64, %15881, %c4096_i64 : tensor<3xi64> + %15882 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_5596, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5597 = tensor.dim %15882, %c1 : tensor<1x?x4096xi64> + %15883 = arith.index_cast %dim_5597 : index to i64 + %from_elements_5598 = tensor.from_elements %c1_i64, %15883, %c4096_i64, %c1_i64 : tensor<4xi64> + %15884 = stablehlo.dynamic_reshape %15882, %from_elements_5598 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %15885 = stablehlo.dynamic_broadcast_in_dim %15880, %from_elements_5596, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5599 = tensor.dim %15885, %c1 : tensor<1x?x4096xi64> + %15886 = arith.index_cast %dim_5599 : index to i64 + %from_elements_5600 = tensor.from_elements %c1_i64, %15886, %c4096_i64, %c1_i64 : tensor<4xi64> + %15887 = stablehlo.dynamic_reshape %15885, %from_elements_5600 : (tensor<1x?x4096xi64>, 
tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %15888 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_5596, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5601 = tensor.dim %15888, %c1 : tensor<1x?x4096xi64> + %15889 = arith.index_cast %dim_5601 : index to i64 + %from_elements_5602 = tensor.from_elements %c1_i64, %15889, %c4096_i64, %c1_i64 : tensor<4xi64> + %15890 = stablehlo.dynamic_reshape %15888, %from_elements_5602 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %15891 = stablehlo.concatenate %15884, %15887, %15890, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %15892 = "stablehlo.gather"(%15499, %15891) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %15893 = shape.shape_of %15892 : tensor<1x?x4096xf32> -> tensor<3xindex> + %15894 = shape.num_elements %15893 : tensor<3xindex> -> index + %15895 = stablehlo.compute_reshape_shape %15894, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %15896 = stablehlo.dynamic_reshape %15892, %15895 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %15897 = stablehlo.dot %15896, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %15898 = stablehlo.logistic %15897 : tensor + %15899 = shape.shape_of %15898 : tensor -> tensor<2xindex> + %15900 = shape.shape_of %15897 : tensor -> tensor<2xindex> + %15901 = shape.cstr_broadcastable %15899, %15900 : tensor<2xindex>, tensor<2xindex> + %15902 = shape.assuming %15901 -> (tensor) { + %19688 = shape.broadcast %15899, %15900 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %15898, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %15897, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %15903 = shape.shape_of %15902 : tensor -> tensor<2xindex> + %15904 = shape.cstr_broadcastable %15903, %15900 : tensor<2xindex>, tensor<2xindex> + %15905 = shape.assuming %15904 -> (tensor) { + %19688 = shape.broadcast %15903, %15900 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %15902, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %15897, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %15906 = stablehlo.dot %15905, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_5603 = tensor.dim %15878, %c0 : tensor + %15907 = arith.index_cast %dim_5603 : index to i64 + %from_elements_5604 = tensor.from_elements %15907, %c1_i64 : tensor<2xi64> + %15908 = stablehlo.dynamic_reshape %15878, %from_elements_5604 : (tensor, tensor<2xi64>) -> tensor + %dim_5605 = tensor.dim %15875, %c0 : tensor + %15909 = arith.index_cast %dim_5605 : index to i64 + %from_elements_5606 = tensor.from_elements %15909, %c1_i64 : tensor<2xi64> + %15910 = stablehlo.dynamic_reshape %15875, %from_elements_5606 : (tensor, tensor<2xi64>) -> tensor + %15911 = stablehlo.concatenate %15908, %15910, dim = 1 : (tensor, tensor) -> tensor + %15912 = "stablehlo.gather"(%15528, %15911) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) 
-> tensor + %15913 = shape.shape_of %15906 : tensor -> tensor<2xindex> + %15914 = shape.shape_of %15912 : tensor -> tensor<2xindex> + %15915 = shape.cstr_broadcastable %15913, %15914 : tensor<2xindex>, tensor<2xindex> + %15916 = shape.assuming %15915 -> (tensor) { + %19688 = shape.broadcast %15913, %15914 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %15906, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %15912, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %15917 = shape.shape_of %15916 : tensor -> tensor<2xindex> + %15918 = stablehlo.dynamic_broadcast_in_dim %15916, %15917, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %15919 = stablehlo.dynamic_broadcast_in_dim %213, %15917, dims = [] : (tensor, tensor<2xindex>) -> tensor + %15920 = stablehlo.multiply %15918, %15919 : tensor + %dim_5607 = tensor.dim %15880, %c0 : tensor + %15921 = arith.index_cast %dim_5607 : index to i64 + %dim_5608 = tensor.dim %15916, %c0 : tensor + %15922 = arith.index_cast %dim_5608 : index to i64 + %15923 = arith.maxsi %15921, %15922 : i64 + %15924 = arith.index_cast %15923 : i64 to index + %from_elements_5609 = tensor.from_elements %15924, %c4096 : tensor<2xindex> + %15925 = stablehlo.dynamic_broadcast_in_dim %15880, %from_elements_5609, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_5610 = tensor.dim %15925, %c0 : tensor + %15926 = arith.index_cast %dim_5610 : index to i64 + %from_elements_5611 = tensor.from_elements %15926, %c4096_i64 : tensor<2xi64> + %15927 = stablehlo.real_dynamic_slice %15920, %c_22, %from_elements_5611, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_5612 = tensor.from_elements %15926, %c4096_i64, %c1_i64 : tensor<3xi64> + %15928 = stablehlo.dynamic_reshape %15925, %from_elements_5612 : (tensor, tensor<3xi64>) -> tensor + %15929 = stablehlo.dynamic_iota %from_elements_5612, dim = 1 : (tensor<3xi64>) -> tensor + %15930 = stablehlo.concatenate %15928, %15929, dim = 2 : (tensor, tensor) -> tensor + %15931 = "stablehlo.scatter"(%15868, %15930, %15927) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %15932 = stablehlo.slice %15488 [7:8, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %15933 = stablehlo.reshape %15932 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %15934 = stablehlo.custom_call @byteir.non_zero(%15933) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_5613 = tensor.dim %15934, %c0 : tensor + %15935 = arith.index_cast %dim_5613 : index to i64 + %from_elements_5614 = tensor.from_elements %15935, %c1_i64 : tensor<2xi64> + %15936 = stablehlo.real_dynamic_slice %15934, %c_22, %from_elements_5614, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5615 = tensor.dim %15936, %c0 : tensor + %15937 = arith.index_cast %dim_5615 : index to i64 + %from_elements_5616 = tensor.from_elements %15937 : tensor<1xi64> + %15938 = stablehlo.dynamic_reshape %15936, %from_elements_5616 : (tensor, tensor<1xi64>) -> tensor + %from_elements_5617 = tensor.from_elements %15935, %c2_i64 : tensor<2xi64> + %15939 = 
stablehlo.real_dynamic_slice %15934, %c_24, %from_elements_5617, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5618 = tensor.dim %15939, %c0 : tensor + %15940 = arith.index_cast %dim_5618 : index to i64 + %from_elements_5619 = tensor.from_elements %15940 : tensor<1xi64> + %15941 = stablehlo.dynamic_reshape %15939, %from_elements_5619 : (tensor, tensor<1xi64>) -> tensor + %dim_5620 = tensor.dim %15941, %c0 : tensor + %15942 = arith.index_cast %dim_5620 : index to i64 + %from_elements_5621 = tensor.from_elements %15942, %c1_i64 : tensor<2xi64> + %15943 = stablehlo.dynamic_reshape %15941, %from_elements_5621 : (tensor, tensor<2xi64>) -> tensor + %dim_5622 = tensor.dim %15943, %c0 : tensor + %15944 = arith.index_cast %dim_5622 : index to i64 + %from_elements_5623 = tensor.from_elements %c1_i64, %15944, %c4096_i64 : tensor<3xi64> + %15945 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_5623, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5624 = tensor.dim %15945, %c1 : tensor<1x?x4096xi64> + %15946 = arith.index_cast %dim_5624 : index to i64 + %from_elements_5625 = tensor.from_elements %c1_i64, %15946, %c4096_i64, %c1_i64 : tensor<4xi64> + %15947 = stablehlo.dynamic_reshape %15945, %from_elements_5625 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %15948 = stablehlo.dynamic_broadcast_in_dim %15943, %from_elements_5623, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5626 = tensor.dim %15948, %c1 : tensor<1x?x4096xi64> + %15949 = arith.index_cast %dim_5626 : index to i64 + %from_elements_5627 = tensor.from_elements %c1_i64, %15949, %c4096_i64, %c1_i64 : tensor<4xi64> + %15950 = stablehlo.dynamic_reshape %15948, %from_elements_5627 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %15951 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_5623, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5628 = tensor.dim %15951, %c1 : tensor<1x?x4096xi64> + %15952 = arith.index_cast %dim_5628 : index to i64 + %from_elements_5629 = tensor.from_elements %c1_i64, %15952, %c4096_i64, %c1_i64 : tensor<4xi64> + %15953 = stablehlo.dynamic_reshape %15951, %from_elements_5629 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %15954 = stablehlo.concatenate %15947, %15950, %15953, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %15955 = "stablehlo.gather"(%15499, %15954) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %15956 = shape.shape_of %15955 : tensor<1x?x4096xf32> -> tensor<3xindex> + %15957 = shape.num_elements %15956 : tensor<3xindex> -> index + %15958 = stablehlo.compute_reshape_shape %15957, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %15959 = stablehlo.dynamic_reshape %15955, %15958 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %15960 = stablehlo.dot %15959, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %15961 = stablehlo.logistic %15960 : tensor + %15962 = shape.shape_of %15961 : tensor -> tensor<2xindex> + %15963 = shape.shape_of %15960 : tensor -> tensor<2xindex> + %15964 = shape.cstr_broadcastable %15962, %15963 : tensor<2xindex>, tensor<2xindex> + %15965 = shape.assuming %15964 -> (tensor) { + %19688 = shape.broadcast %15962, %15963 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = 
stablehlo.dynamic_broadcast_in_dim %15961, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %15960, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %15966 = shape.shape_of %15965 : tensor -> tensor<2xindex> + %15967 = shape.cstr_broadcastable %15966, %15963 : tensor<2xindex>, tensor<2xindex> + %15968 = shape.assuming %15967 -> (tensor) { + %19688 = shape.broadcast %15966, %15963 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %15965, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %15960, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %15969 = stablehlo.dot %15968, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_5630 = tensor.dim %15941, %c0 : tensor + %15970 = arith.index_cast %dim_5630 : index to i64 + %from_elements_5631 = tensor.from_elements %15970, %c1_i64 : tensor<2xi64> + %15971 = stablehlo.dynamic_reshape %15941, %from_elements_5631 : (tensor, tensor<2xi64>) -> tensor + %dim_5632 = tensor.dim %15938, %c0 : tensor + %15972 = arith.index_cast %dim_5632 : index to i64 + %from_elements_5633 = tensor.from_elements %15972, %c1_i64 : tensor<2xi64> + %15973 = stablehlo.dynamic_reshape %15938, %from_elements_5633 : (tensor, tensor<2xi64>) -> tensor + %15974 = stablehlo.concatenate %15971, %15973, dim = 1 : (tensor, tensor) -> tensor + %15975 = "stablehlo.gather"(%15528, %15974) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %15976 = shape.shape_of %15969 : tensor -> tensor<2xindex> + %15977 = shape.shape_of %15975 : tensor -> tensor<2xindex> + %15978 = shape.cstr_broadcastable %15976, %15977 : tensor<2xindex>, tensor<2xindex> + %15979 = shape.assuming %15978 -> (tensor) { + %19688 = shape.broadcast %15976, %15977 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %15969, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %15975, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %15980 = shape.shape_of %15979 : tensor -> tensor<2xindex> + %15981 = stablehlo.dynamic_broadcast_in_dim %15979, %15980, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %15982 = stablehlo.dynamic_broadcast_in_dim %213, %15980, dims = [] : (tensor, tensor<2xindex>) -> tensor + %15983 = stablehlo.multiply %15981, %15982 : tensor + %dim_5634 = tensor.dim %15943, %c0 : tensor + %15984 = arith.index_cast %dim_5634 : index to i64 + %dim_5635 = tensor.dim %15979, %c0 : tensor + %15985 = arith.index_cast %dim_5635 : index to i64 + %15986 = arith.maxsi %15984, %15985 : i64 + %15987 = arith.index_cast %15986 : i64 to index + %from_elements_5636 = tensor.from_elements %15987, %c4096 : tensor<2xindex> + %15988 = stablehlo.dynamic_broadcast_in_dim %15943, %from_elements_5636, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_5637 = tensor.dim %15988, %c0 : tensor + %15989 = arith.index_cast %dim_5637 : index to i64 + %from_elements_5638 = tensor.from_elements %15989, %c4096_i64 : tensor<2xi64> + %15990 = 
stablehlo.real_dynamic_slice %15983, %c_22, %from_elements_5638, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_5639 = tensor.from_elements %15989, %c4096_i64, %c1_i64 : tensor<3xi64> + %15991 = stablehlo.dynamic_reshape %15988, %from_elements_5639 : (tensor, tensor<3xi64>) -> tensor + %15992 = stablehlo.dynamic_iota %from_elements_5639, dim = 1 : (tensor<3xi64>) -> tensor + %15993 = stablehlo.concatenate %15991, %15992, dim = 2 : (tensor, tensor) -> tensor + %15994 = "stablehlo.scatter"(%15931, %15993, %15990) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %15995 = stablehlo.reshape %15994 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %15996 = stablehlo.add %15461, %15995 : tensor<3x1x4096xf32> + %15997 = stablehlo.broadcast_in_dim %15996, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %15998 = stablehlo.power %15997, %15 : tensor<3x1x4096xf32> + %15999 = stablehlo.reduce(%15998 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %16000 = stablehlo.reshape %15999 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %16001 = stablehlo.broadcast_in_dim %16000, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %16002 = stablehlo.divide %16001, %21 : tensor<3x1x1xf32> + %16003 = stablehlo.broadcast_in_dim %16002, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %16004 = stablehlo.add %16003, %25 : tensor<3x1x1xf32> + %16005 = stablehlo.rsqrt %16004 : tensor<3x1x1xf32> + %16006 = stablehlo.broadcast_in_dim %16005, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %16007 = stablehlo.multiply %15997, %16006 : tensor<3x1x4096xf32> + %16008 = stablehlo.broadcast_in_dim %16007, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %16009 = stablehlo.multiply %16008, %31 : tensor<3x1x4096xf32> + %16010 = stablehlo.reshape %16009 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %16011 = stablehlo.dot %16010, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %16012 = stablehlo.reshape %16011 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %16013 = stablehlo.dot %16010, %37 : (tensor<3x4096xf32>, tensor<4096x1024xf32>) -> tensor<3x1024xf32> + %16014 = stablehlo.reshape %16013 : (tensor<3x1024xf32>) -> tensor<3x1x1024xf32> + %16015 = stablehlo.reshape %16012 : (tensor<3x1x4096xf32>) -> tensor<3x1x32x128xf32> + %16016 = stablehlo.transpose %16015, dims = [0, 2, 1, 3] : (tensor<3x1x32x128xf32>) -> tensor<3x32x1x128xf32> + %16017 = stablehlo.reshape %16014 : (tensor<3x1x1024xf32>) -> tensor<3x1x8x128xf32> + %16018 = stablehlo.transpose %16017, dims = [0, 2, 1, 3] : (tensor<3x1x8x128xf32>) -> tensor<3x8x1x128xf32> + %16019 = stablehlo.slice %arg52 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %16020 = stablehlo.slice %arg53 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %16021 = "stablehlo.gather"(%16019, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %16022 = stablehlo.reshape %16021 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %16023 = "stablehlo.gather"(%16020, %46) <{dimension_numbers = #stablehlo.gather, 
indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %16024 = stablehlo.reshape %16023 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %16025 = stablehlo.broadcast_in_dim %16016, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %16026 = stablehlo.broadcast_in_dim %16022, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %16027 = stablehlo.multiply %16025, %16026 : tensor<3x32x1x128xf32> + %16028 = stablehlo.slice %16016 [0:3, 0:32, 0:1, 0:64] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %16029 = stablehlo.slice %16016 [0:3, 0:32, 0:1, 64:128] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %16030 = stablehlo.negate %16029 : tensor<3x32x1x64xf32> + %16031 = stablehlo.concatenate %16030, %16028, dim = 3 : (tensor<3x32x1x64xf32>, tensor<3x32x1x64xf32>) -> tensor<3x32x1x128xf32> + %16032 = stablehlo.broadcast_in_dim %16031, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %16033 = stablehlo.broadcast_in_dim %16024, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %16034 = stablehlo.multiply %16032, %16033 : tensor<3x32x1x128xf32> + %16035 = stablehlo.add %16027, %16034 : tensor<3x32x1x128xf32> + %16036 = stablehlo.broadcast_in_dim %16018, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %16037 = stablehlo.broadcast_in_dim %16022, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %16038 = stablehlo.multiply %16036, %16037 : tensor<3x8x1x128xf32> + %16039 = stablehlo.slice %16018 [0:3, 0:8, 0:1, 0:64] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %16040 = stablehlo.slice %16018 [0:3, 0:8, 0:1, 64:128] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %16041 = stablehlo.negate %16040 : tensor<3x8x1x64xf32> + %16042 = stablehlo.concatenate %16041, %16039, dim = 3 : (tensor<3x8x1x64xf32>, tensor<3x8x1x64xf32>) -> tensor<3x8x1x128xf32> + %16043 = stablehlo.broadcast_in_dim %16042, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %16044 = stablehlo.broadcast_in_dim %16024, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %16045 = stablehlo.multiply %16043, %16044 : tensor<3x8x1x128xf32> + %16046 = stablehlo.add %16038, %16045 : tensor<3x8x1x128xf32> + %16047 = stablehlo.concatenate %arg117, %16046, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %16048 = stablehlo.concatenate %arg118, %16018, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %16049 = stablehlo.reshape %16047 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %16050 = stablehlo.broadcast_in_dim %16049, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %16051 = stablehlo.reshape %16050 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %16052 = stablehlo.reshape %16048 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %16053 = stablehlo.broadcast_in_dim %16052, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %16054 = stablehlo.reshape %16053 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %16055 = stablehlo.transpose %16051, dims = [0, 1, 3, 2] : (tensor<3x32x8x128xf32>) -> tensor<3x32x128x8xf32> + %16056 = stablehlo.reshape %16035 : (tensor<3x32x1x128xf32>) -> tensor<96x1x128xf32> + %16057 = stablehlo.reshape %16055 : (tensor<3x32x128x8xf32>) -> tensor<96x128x8xf32> + %16058 = stablehlo.broadcast_in_dim 
%16057, dims = [0, 1, 2] : (tensor<96x128x8xf32>) -> tensor<96x128x8xf32> + %16059 = stablehlo.dot_general %16056, %16058, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x128xf32>, tensor<96x128x8xf32>) -> tensor<96x1x8xf32> + %16060 = stablehlo.reshape %16059 : (tensor<96x1x8xf32>) -> tensor<3x32x1x8xf32> + %16061 = stablehlo.broadcast_in_dim %16060, dims = [0, 1, 2, 3] : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %16062 = stablehlo.divide %16061, %89 : tensor<3x32x1x8xf32> + %16063 = stablehlo.custom_call @byteir.softmax(%16062) {byteir_attrs = {axis = 3 : i64}} : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %16064 = stablehlo.reshape %16063 : (tensor<3x32x1x8xf32>) -> tensor<96x1x8xf32> + %16065 = stablehlo.reshape %16054 : (tensor<3x32x8x128xf32>) -> tensor<96x8x128xf32> + %16066 = stablehlo.broadcast_in_dim %16065, dims = [0, 1, 2] : (tensor<96x8x128xf32>) -> tensor<96x8x128xf32> + %16067 = stablehlo.dot_general %16064, %16066, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x8xf32>, tensor<96x8x128xf32>) -> tensor<96x1x128xf32> + %16068 = stablehlo.reshape %16067 : (tensor<96x1x128xf32>) -> tensor<3x32x1x128xf32> + %16069 = stablehlo.transpose %16068, dims = [0, 2, 1, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x1x32x128xf32> + %16070 = stablehlo.reshape %16069 : (tensor<3x1x32x128xf32>) -> tensor<3x1x4096xf32> + %16071 = stablehlo.reshape %16070 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %16072 = stablehlo.dot %16071, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %16073 = stablehlo.reshape %16072 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %16074 = stablehlo.add %15996, %16073 : tensor<3x1x4096xf32> + %16075 = stablehlo.broadcast_in_dim %16074, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %16076 = stablehlo.power %16075, %15 : tensor<3x1x4096xf32> + %16077 = stablehlo.reduce(%16076 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %16078 = stablehlo.reshape %16077 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %16079 = stablehlo.broadcast_in_dim %16078, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %16080 = stablehlo.divide %16079, %21 : tensor<3x1x1xf32> + %16081 = stablehlo.broadcast_in_dim %16080, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %16082 = stablehlo.add %16081, %25 : tensor<3x1x1xf32> + %16083 = stablehlo.rsqrt %16082 : tensor<3x1x1xf32> + %16084 = stablehlo.broadcast_in_dim %16083, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %16085 = stablehlo.multiply %16075, %16084 : tensor<3x1x4096xf32> + %16086 = stablehlo.broadcast_in_dim %16085, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %16087 = stablehlo.multiply %16086, %31 : tensor<3x1x4096xf32> + %16088 = stablehlo.reshape %16087 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %16089 = stablehlo.dot %16088, %117 : (tensor<3x4096xf32>, tensor<4096x8xf32>) -> tensor<3x8xf32> + %16090 = stablehlo.custom_call @byteir.softmax(%16089) {byteir_attrs = {axis = 1 : i64}} : (tensor<3x8xf32>) -> tensor<3x8xf32> + %16091:2 = stablehlo.custom_call @byteir.top_k(%16090) {byteir_attrs = {axis = [1], k = 2 : i64, sorted = true}} : (tensor<3x8xf32>) -> (tensor<3x2xf32>, tensor<3x2xi64>) + %16092 = stablehlo.reduce(%16091#0 init: %cst_20) applies stablehlo.add across dimensions = [1] : (tensor<3x2xf32>, tensor) -> tensor<3xf32> + %16093 = stablehlo.reshape %16092 : (tensor<3xf32>) -> 
tensor<3x1xf32> + %16094 = stablehlo.broadcast_in_dim %16091#0, dims = [0, 1] : (tensor<3x2xf32>) -> tensor<3x2xf32> + %16095 = stablehlo.broadcast_in_dim %16093, dims = [0, 1] : (tensor<3x1xf32>) -> tensor<3x2xf32> + %16096 = stablehlo.divide %16094, %16095 : tensor<3x2xf32> + %16097 = stablehlo.reshape %16091#1 : (tensor<3x2xi64>) -> tensor<3x2x1xi64> + %16098 = stablehlo.broadcast_in_dim %16097, dims = [0, 1, 2] : (tensor<3x2x1xi64>) -> tensor<3x2x8xi64> + %16099 = stablehlo.compare EQ, %16098, %137, SIGNED : (tensor<3x2x8xi64>, tensor<3x2x8xi64>) -> tensor<3x2x8xi1> + %16100 = stablehlo.convert %16099 : (tensor<3x2x8xi1>) -> tensor<3x2x8xi64> + %16101 = stablehlo.transpose %16100, dims = [2, 1, 0] : (tensor<3x2x8xi64>) -> tensor<8x2x3xi64> + %16102 = stablehlo.slice %16101 [0:1, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %16103 = stablehlo.reshape %16102 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %16104 = stablehlo.custom_call @byteir.non_zero(%16103) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_5640 = tensor.dim %16104, %c0 : tensor + %16105 = arith.index_cast %dim_5640 : index to i64 + %from_elements_5641 = tensor.from_elements %16105, %c1_i64 : tensor<2xi64> + %16106 = stablehlo.real_dynamic_slice %16104, %c_22, %from_elements_5641, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5642 = tensor.dim %16106, %c0 : tensor + %16107 = arith.index_cast %dim_5642 : index to i64 + %from_elements_5643 = tensor.from_elements %16107 : tensor<1xi64> + %16108 = stablehlo.dynamic_reshape %16106, %from_elements_5643 : (tensor, tensor<1xi64>) -> tensor + %from_elements_5644 = tensor.from_elements %16105, %c2_i64 : tensor<2xi64> + %16109 = stablehlo.real_dynamic_slice %16104, %c_24, %from_elements_5644, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5645 = tensor.dim %16109, %c0 : tensor + %16110 = arith.index_cast %dim_5645 : index to i64 + %from_elements_5646 = tensor.from_elements %16110 : tensor<1xi64> + %16111 = stablehlo.dynamic_reshape %16109, %from_elements_5646 : (tensor, tensor<1xi64>) -> tensor + %16112 = stablehlo.reshape %16088 : (tensor<3x4096xf32>) -> tensor<1x3x4096xf32> + %dim_5647 = tensor.dim %16111, %c0 : tensor + %16113 = arith.index_cast %dim_5647 : index to i64 + %from_elements_5648 = tensor.from_elements %16113, %c1_i64 : tensor<2xi64> + %16114 = stablehlo.dynamic_reshape %16111, %from_elements_5648 : (tensor, tensor<2xi64>) -> tensor + %dim_5649 = tensor.dim %16114, %c0 : tensor + %16115 = arith.index_cast %dim_5649 : index to i64 + %from_elements_5650 = tensor.from_elements %c1_i64, %16115, %c4096_i64 : tensor<3xi64> + %16116 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_5650, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5651 = tensor.dim %16116, %c1 : tensor<1x?x4096xi64> + %16117 = arith.index_cast %dim_5651 : index to i64 + %from_elements_5652 = tensor.from_elements %c1_i64, %16117, %c4096_i64, %c1_i64 : tensor<4xi64> + %16118 = stablehlo.dynamic_reshape %16116, %from_elements_5652 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %16119 = stablehlo.dynamic_broadcast_in_dim %16114, %from_elements_5650, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5653 = tensor.dim %16119, %c1 : tensor<1x?x4096xi64> + %16120 = arith.index_cast %dim_5653 : index to i64 + %from_elements_5654 = tensor.from_elements %c1_i64, %16120, %c4096_i64, %c1_i64 : tensor<4xi64> + %16121 = stablehlo.dynamic_reshape %16119, 
%from_elements_5654 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %16122 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_5650, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5655 = tensor.dim %16122, %c1 : tensor<1x?x4096xi64> + %16123 = arith.index_cast %dim_5655 : index to i64 + %from_elements_5656 = tensor.from_elements %c1_i64, %16123, %c4096_i64, %c1_i64 : tensor<4xi64> + %16124 = stablehlo.dynamic_reshape %16122, %from_elements_5656 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %16125 = stablehlo.concatenate %16118, %16121, %16124, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %16126 = "stablehlo.gather"(%16112, %16125) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %16127 = shape.shape_of %16126 : tensor<1x?x4096xf32> -> tensor<3xindex> + %16128 = shape.num_elements %16127 : tensor<3xindex> -> index + %16129 = stablehlo.compute_reshape_shape %16128, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %16130 = stablehlo.dynamic_reshape %16126, %16129 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %16131 = stablehlo.dot %16130, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %16132 = stablehlo.logistic %16131 : tensor + %16133 = shape.shape_of %16132 : tensor -> tensor<2xindex> + %16134 = shape.shape_of %16131 : tensor -> tensor<2xindex> + %16135 = shape.cstr_broadcastable %16133, %16134 : tensor<2xindex>, tensor<2xindex> + %16136 = shape.assuming %16135 -> (tensor) { + %19688 = shape.broadcast %16133, %16134 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %16132, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %16131, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %16137 = shape.shape_of %16136 : tensor -> tensor<2xindex> + %16138 = shape.cstr_broadcastable %16137, %16134 : tensor<2xindex>, tensor<2xindex> + %16139 = shape.assuming %16138 -> (tensor) { + %19688 = shape.broadcast %16137, %16134 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %16136, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %16131, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %16140 = stablehlo.dot %16139, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %16141 = stablehlo.reshape %16096 : (tensor<3x2xf32>) -> tensor<3x2x1xf32> + %dim_5657 = tensor.dim %16111, %c0 : tensor + %16142 = arith.index_cast %dim_5657 : index to i64 + %from_elements_5658 = tensor.from_elements %16142, %c1_i64 : tensor<2xi64> + %16143 = stablehlo.dynamic_reshape %16111, %from_elements_5658 : (tensor, tensor<2xi64>) -> tensor + %dim_5659 = tensor.dim %16108, %c0 : tensor + %16144 = arith.index_cast %dim_5659 : index to i64 + %from_elements_5660 = tensor.from_elements %16144, %c1_i64 : tensor<2xi64> + %16145 = stablehlo.dynamic_reshape %16108, %from_elements_5660 : (tensor, tensor<2xi64>) -> tensor + %16146 = stablehlo.concatenate %16143, %16145, dim = 1 : (tensor, tensor) -> tensor + %16147 = "stablehlo.gather"(%16141, %16146) 
<{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %16148 = shape.shape_of %16140 : tensor -> tensor<2xindex> + %16149 = shape.shape_of %16147 : tensor -> tensor<2xindex> + %16150 = shape.cstr_broadcastable %16148, %16149 : tensor<2xindex>, tensor<2xindex> + %16151 = shape.assuming %16150 -> (tensor) { + %19688 = shape.broadcast %16148, %16149 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %16140, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %16147, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %16152 = shape.shape_of %16151 : tensor -> tensor<2xindex> + %16153 = stablehlo.dynamic_broadcast_in_dim %16151, %16152, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %16154 = stablehlo.dynamic_broadcast_in_dim %213, %16152, dims = [] : (tensor, tensor<2xindex>) -> tensor + %16155 = stablehlo.multiply %16153, %16154 : tensor + %dim_5661 = tensor.dim %16114, %c0 : tensor + %16156 = arith.index_cast %dim_5661 : index to i64 + %dim_5662 = tensor.dim %16151, %c0 : tensor + %16157 = arith.index_cast %dim_5662 : index to i64 + %16158 = arith.maxsi %16156, %16157 : i64 + %16159 = arith.index_cast %16158 : i64 to index + %from_elements_5663 = tensor.from_elements %16159, %c4096 : tensor<2xindex> + %16160 = stablehlo.dynamic_broadcast_in_dim %16114, %from_elements_5663, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_5664 = tensor.dim %16160, %c0 : tensor + %16161 = arith.index_cast %dim_5664 : index to i64 + %from_elements_5665 = tensor.from_elements %16161, %c4096_i64 : tensor<2xi64> + %16162 = stablehlo.real_dynamic_slice %16155, %c_22, %from_elements_5665, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_5666 = tensor.from_elements %16161, %c4096_i64, %c1_i64 : tensor<3xi64> + %16163 = stablehlo.dynamic_reshape %16160, %from_elements_5666 : (tensor, tensor<3xi64>) -> tensor + %16164 = stablehlo.dynamic_iota %from_elements_5666, dim = 1 : (tensor<3xi64>) -> tensor + %16165 = stablehlo.concatenate %16163, %16164, dim = 2 : (tensor, tensor) -> tensor + %16166 = "stablehlo.scatter"(%cst_2, %16165, %16162) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %16167 = stablehlo.slice %16101 [1:2, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %16168 = stablehlo.reshape %16167 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %16169 = stablehlo.custom_call @byteir.non_zero(%16168) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_5667 = tensor.dim %16169, %c0 : tensor + %16170 = arith.index_cast %dim_5667 : index to i64 + %from_elements_5668 = tensor.from_elements %16170, %c1_i64 : tensor<2xi64> + %16171 = stablehlo.real_dynamic_slice %16169, %c_22, %from_elements_5668, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5669 = tensor.dim %16171, %c0 : tensor + %16172 = arith.index_cast %dim_5669 : index to i64 + %from_elements_5670 = tensor.from_elements %16172 : tensor<1xi64> + %16173 = stablehlo.dynamic_reshape %16171, %from_elements_5670 : (tensor, tensor<1xi64>) -> 
tensor + %from_elements_5671 = tensor.from_elements %16170, %c2_i64 : tensor<2xi64> + %16174 = stablehlo.real_dynamic_slice %16169, %c_24, %from_elements_5671, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5672 = tensor.dim %16174, %c0 : tensor + %16175 = arith.index_cast %dim_5672 : index to i64 + %from_elements_5673 = tensor.from_elements %16175 : tensor<1xi64> + %16176 = stablehlo.dynamic_reshape %16174, %from_elements_5673 : (tensor, tensor<1xi64>) -> tensor + %dim_5674 = tensor.dim %16176, %c0 : tensor + %16177 = arith.index_cast %dim_5674 : index to i64 + %from_elements_5675 = tensor.from_elements %16177, %c1_i64 : tensor<2xi64> + %16178 = stablehlo.dynamic_reshape %16176, %from_elements_5675 : (tensor, tensor<2xi64>) -> tensor + %dim_5676 = tensor.dim %16178, %c0 : tensor + %16179 = arith.index_cast %dim_5676 : index to i64 + %from_elements_5677 = tensor.from_elements %c1_i64, %16179, %c4096_i64 : tensor<3xi64> + %16180 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_5677, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5678 = tensor.dim %16180, %c1 : tensor<1x?x4096xi64> + %16181 = arith.index_cast %dim_5678 : index to i64 + %from_elements_5679 = tensor.from_elements %c1_i64, %16181, %c4096_i64, %c1_i64 : tensor<4xi64> + %16182 = stablehlo.dynamic_reshape %16180, %from_elements_5679 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %16183 = stablehlo.dynamic_broadcast_in_dim %16178, %from_elements_5677, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5680 = tensor.dim %16183, %c1 : tensor<1x?x4096xi64> + %16184 = arith.index_cast %dim_5680 : index to i64 + %from_elements_5681 = tensor.from_elements %c1_i64, %16184, %c4096_i64, %c1_i64 : tensor<4xi64> + %16185 = stablehlo.dynamic_reshape %16183, %from_elements_5681 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %16186 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_5677, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5682 = tensor.dim %16186, %c1 : tensor<1x?x4096xi64> + %16187 = arith.index_cast %dim_5682 : index to i64 + %from_elements_5683 = tensor.from_elements %c1_i64, %16187, %c4096_i64, %c1_i64 : tensor<4xi64> + %16188 = stablehlo.dynamic_reshape %16186, %from_elements_5683 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %16189 = stablehlo.concatenate %16182, %16185, %16188, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %16190 = "stablehlo.gather"(%16112, %16189) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %16191 = shape.shape_of %16190 : tensor<1x?x4096xf32> -> tensor<3xindex> + %16192 = shape.num_elements %16191 : tensor<3xindex> -> index + %16193 = stablehlo.compute_reshape_shape %16192, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %16194 = stablehlo.dynamic_reshape %16190, %16193 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %16195 = stablehlo.dot %16194, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %16196 = stablehlo.logistic %16195 : tensor + %16197 = shape.shape_of %16196 : tensor -> tensor<2xindex> + %16198 = shape.shape_of %16195 : tensor -> tensor<2xindex> + %16199 = shape.cstr_broadcastable %16197, %16198 : tensor<2xindex>, tensor<2xindex> + %16200 = shape.assuming %16199 -> (tensor) { + %19688 = shape.broadcast 
%16197, %16198 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %16196, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %16195, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %16201 = shape.shape_of %16200 : tensor -> tensor<2xindex> + %16202 = shape.cstr_broadcastable %16201, %16198 : tensor<2xindex>, tensor<2xindex> + %16203 = shape.assuming %16202 -> (tensor) { + %19688 = shape.broadcast %16201, %16198 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %16200, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %16195, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %16204 = stablehlo.dot %16203, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_5684 = tensor.dim %16176, %c0 : tensor + %16205 = arith.index_cast %dim_5684 : index to i64 + %from_elements_5685 = tensor.from_elements %16205, %c1_i64 : tensor<2xi64> + %16206 = stablehlo.dynamic_reshape %16176, %from_elements_5685 : (tensor, tensor<2xi64>) -> tensor + %dim_5686 = tensor.dim %16173, %c0 : tensor + %16207 = arith.index_cast %dim_5686 : index to i64 + %from_elements_5687 = tensor.from_elements %16207, %c1_i64 : tensor<2xi64> + %16208 = stablehlo.dynamic_reshape %16173, %from_elements_5687 : (tensor, tensor<2xi64>) -> tensor + %16209 = stablehlo.concatenate %16206, %16208, dim = 1 : (tensor, tensor) -> tensor + %16210 = "stablehlo.gather"(%16141, %16209) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %16211 = shape.shape_of %16204 : tensor -> tensor<2xindex> + %16212 = shape.shape_of %16210 : tensor -> tensor<2xindex> + %16213 = shape.cstr_broadcastable %16211, %16212 : tensor<2xindex>, tensor<2xindex> + %16214 = shape.assuming %16213 -> (tensor) { + %19688 = shape.broadcast %16211, %16212 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %16204, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %16210, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %16215 = shape.shape_of %16214 : tensor -> tensor<2xindex> + %16216 = stablehlo.dynamic_broadcast_in_dim %16214, %16215, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %16217 = stablehlo.dynamic_broadcast_in_dim %213, %16215, dims = [] : (tensor, tensor<2xindex>) -> tensor + %16218 = stablehlo.multiply %16216, %16217 : tensor + %dim_5688 = tensor.dim %16178, %c0 : tensor + %16219 = arith.index_cast %dim_5688 : index to i64 + %dim_5689 = tensor.dim %16214, %c0 : tensor + %16220 = arith.index_cast %dim_5689 : index to i64 + %16221 = arith.maxsi %16219, %16220 : i64 + %16222 = arith.index_cast %16221 : i64 to index + %from_elements_5690 = tensor.from_elements %16222, %c4096 : tensor<2xindex> + %16223 = stablehlo.dynamic_broadcast_in_dim %16178, %from_elements_5690, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_5691 = tensor.dim %16223, %c0 : tensor + %16224 = arith.index_cast %dim_5691 : index to i64 + %from_elements_5692 = 
tensor.from_elements %16224, %c4096_i64 : tensor<2xi64> + %16225 = stablehlo.real_dynamic_slice %16218, %c_22, %from_elements_5692, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_5693 = tensor.from_elements %16224, %c4096_i64, %c1_i64 : tensor<3xi64> + %16226 = stablehlo.dynamic_reshape %16223, %from_elements_5693 : (tensor, tensor<3xi64>) -> tensor + %16227 = stablehlo.dynamic_iota %from_elements_5693, dim = 1 : (tensor<3xi64>) -> tensor + %16228 = stablehlo.concatenate %16226, %16227, dim = 2 : (tensor, tensor) -> tensor + %16229 = "stablehlo.scatter"(%16166, %16228, %16225) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %16230 = stablehlo.slice %16101 [2:3, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %16231 = stablehlo.reshape %16230 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %16232 = stablehlo.custom_call @byteir.non_zero(%16231) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_5694 = tensor.dim %16232, %c0 : tensor + %16233 = arith.index_cast %dim_5694 : index to i64 + %from_elements_5695 = tensor.from_elements %16233, %c1_i64 : tensor<2xi64> + %16234 = stablehlo.real_dynamic_slice %16232, %c_22, %from_elements_5695, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5696 = tensor.dim %16234, %c0 : tensor + %16235 = arith.index_cast %dim_5696 : index to i64 + %from_elements_5697 = tensor.from_elements %16235 : tensor<1xi64> + %16236 = stablehlo.dynamic_reshape %16234, %from_elements_5697 : (tensor, tensor<1xi64>) -> tensor + %from_elements_5698 = tensor.from_elements %16233, %c2_i64 : tensor<2xi64> + %16237 = stablehlo.real_dynamic_slice %16232, %c_24, %from_elements_5698, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5699 = tensor.dim %16237, %c0 : tensor + %16238 = arith.index_cast %dim_5699 : index to i64 + %from_elements_5700 = tensor.from_elements %16238 : tensor<1xi64> + %16239 = stablehlo.dynamic_reshape %16237, %from_elements_5700 : (tensor, tensor<1xi64>) -> tensor + %dim_5701 = tensor.dim %16239, %c0 : tensor + %16240 = arith.index_cast %dim_5701 : index to i64 + %from_elements_5702 = tensor.from_elements %16240, %c1_i64 : tensor<2xi64> + %16241 = stablehlo.dynamic_reshape %16239, %from_elements_5702 : (tensor, tensor<2xi64>) -> tensor + %dim_5703 = tensor.dim %16241, %c0 : tensor + %16242 = arith.index_cast %dim_5703 : index to i64 + %from_elements_5704 = tensor.from_elements %c1_i64, %16242, %c4096_i64 : tensor<3xi64> + %16243 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_5704, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5705 = tensor.dim %16243, %c1 : tensor<1x?x4096xi64> + %16244 = arith.index_cast %dim_5705 : index to i64 + %from_elements_5706 = tensor.from_elements %c1_i64, %16244, %c4096_i64, %c1_i64 : tensor<4xi64> + %16245 = stablehlo.dynamic_reshape %16243, %from_elements_5706 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %16246 = stablehlo.dynamic_broadcast_in_dim %16241, %from_elements_5704, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5707 = tensor.dim %16246, %c1 : tensor<1x?x4096xi64> + %16247 = arith.index_cast %dim_5707 : index to i64 + %from_elements_5708 = tensor.from_elements %c1_i64, 
%16247, %c4096_i64, %c1_i64 : tensor<4xi64> + %16248 = stablehlo.dynamic_reshape %16246, %from_elements_5708 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %16249 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_5704, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5709 = tensor.dim %16249, %c1 : tensor<1x?x4096xi64> + %16250 = arith.index_cast %dim_5709 : index to i64 + %from_elements_5710 = tensor.from_elements %c1_i64, %16250, %c4096_i64, %c1_i64 : tensor<4xi64> + %16251 = stablehlo.dynamic_reshape %16249, %from_elements_5710 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %16252 = stablehlo.concatenate %16245, %16248, %16251, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %16253 = "stablehlo.gather"(%16112, %16252) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %16254 = shape.shape_of %16253 : tensor<1x?x4096xf32> -> tensor<3xindex> + %16255 = shape.num_elements %16254 : tensor<3xindex> -> index + %16256 = stablehlo.compute_reshape_shape %16255, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %16257 = stablehlo.dynamic_reshape %16253, %16256 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %16258 = stablehlo.dot %16257, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %16259 = stablehlo.logistic %16258 : tensor + %16260 = shape.shape_of %16259 : tensor -> tensor<2xindex> + %16261 = shape.shape_of %16258 : tensor -> tensor<2xindex> + %16262 = shape.cstr_broadcastable %16260, %16261 : tensor<2xindex>, tensor<2xindex> + %16263 = shape.assuming %16262 -> (tensor) { + %19688 = shape.broadcast %16260, %16261 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %16259, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %16258, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %16264 = shape.shape_of %16263 : tensor -> tensor<2xindex> + %16265 = shape.cstr_broadcastable %16264, %16261 : tensor<2xindex>, tensor<2xindex> + %16266 = shape.assuming %16265 -> (tensor) { + %19688 = shape.broadcast %16264, %16261 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %16263, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %16258, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %16267 = stablehlo.dot %16266, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_5711 = tensor.dim %16239, %c0 : tensor + %16268 = arith.index_cast %dim_5711 : index to i64 + %from_elements_5712 = tensor.from_elements %16268, %c1_i64 : tensor<2xi64> + %16269 = stablehlo.dynamic_reshape %16239, %from_elements_5712 : (tensor, tensor<2xi64>) -> tensor + %dim_5713 = tensor.dim %16236, %c0 : tensor + %16270 = arith.index_cast %dim_5713 : index to i64 + %from_elements_5714 = tensor.from_elements %16270, %c1_i64 : tensor<2xi64> + %16271 = stablehlo.dynamic_reshape %16236, %from_elements_5714 : (tensor, tensor<2xi64>) -> tensor + %16272 = stablehlo.concatenate %16269, %16271, dim = 1 : (tensor, tensor) -> tensor + %16273 = 
"stablehlo.gather"(%16141, %16272) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %16274 = shape.shape_of %16267 : tensor -> tensor<2xindex> + %16275 = shape.shape_of %16273 : tensor -> tensor<2xindex> + %16276 = shape.cstr_broadcastable %16274, %16275 : tensor<2xindex>, tensor<2xindex> + %16277 = shape.assuming %16276 -> (tensor) { + %19688 = shape.broadcast %16274, %16275 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %16267, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %16273, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %16278 = shape.shape_of %16277 : tensor -> tensor<2xindex> + %16279 = stablehlo.dynamic_broadcast_in_dim %16277, %16278, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %16280 = stablehlo.dynamic_broadcast_in_dim %213, %16278, dims = [] : (tensor, tensor<2xindex>) -> tensor + %16281 = stablehlo.multiply %16279, %16280 : tensor + %dim_5715 = tensor.dim %16241, %c0 : tensor + %16282 = arith.index_cast %dim_5715 : index to i64 + %dim_5716 = tensor.dim %16277, %c0 : tensor + %16283 = arith.index_cast %dim_5716 : index to i64 + %16284 = arith.maxsi %16282, %16283 : i64 + %16285 = arith.index_cast %16284 : i64 to index + %from_elements_5717 = tensor.from_elements %16285, %c4096 : tensor<2xindex> + %16286 = stablehlo.dynamic_broadcast_in_dim %16241, %from_elements_5717, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_5718 = tensor.dim %16286, %c0 : tensor + %16287 = arith.index_cast %dim_5718 : index to i64 + %from_elements_5719 = tensor.from_elements %16287, %c4096_i64 : tensor<2xi64> + %16288 = stablehlo.real_dynamic_slice %16281, %c_22, %from_elements_5719, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_5720 = tensor.from_elements %16287, %c4096_i64, %c1_i64 : tensor<3xi64> + %16289 = stablehlo.dynamic_reshape %16286, %from_elements_5720 : (tensor, tensor<3xi64>) -> tensor + %16290 = stablehlo.dynamic_iota %from_elements_5720, dim = 1 : (tensor<3xi64>) -> tensor + %16291 = stablehlo.concatenate %16289, %16290, dim = 2 : (tensor, tensor) -> tensor + %16292 = "stablehlo.scatter"(%16229, %16291, %16288) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %16293 = stablehlo.slice %16101 [3:4, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %16294 = stablehlo.reshape %16293 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %16295 = stablehlo.custom_call @byteir.non_zero(%16294) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_5721 = tensor.dim %16295, %c0 : tensor + %16296 = arith.index_cast %dim_5721 : index to i64 + %from_elements_5722 = tensor.from_elements %16296, %c1_i64 : tensor<2xi64> + %16297 = stablehlo.real_dynamic_slice %16295, %c_22, %from_elements_5722, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5723 = tensor.dim %16297, %c0 : tensor + %16298 = arith.index_cast %dim_5723 : index to i64 + %from_elements_5724 = tensor.from_elements %16298 : tensor<1xi64> + %16299 = stablehlo.dynamic_reshape %16297, 
%from_elements_5724 : (tensor, tensor<1xi64>) -> tensor + %from_elements_5725 = tensor.from_elements %16296, %c2_i64 : tensor<2xi64> + %16300 = stablehlo.real_dynamic_slice %16295, %c_24, %from_elements_5725, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5726 = tensor.dim %16300, %c0 : tensor + %16301 = arith.index_cast %dim_5726 : index to i64 + %from_elements_5727 = tensor.from_elements %16301 : tensor<1xi64> + %16302 = stablehlo.dynamic_reshape %16300, %from_elements_5727 : (tensor, tensor<1xi64>) -> tensor + %dim_5728 = tensor.dim %16302, %c0 : tensor + %16303 = arith.index_cast %dim_5728 : index to i64 + %from_elements_5729 = tensor.from_elements %16303, %c1_i64 : tensor<2xi64> + %16304 = stablehlo.dynamic_reshape %16302, %from_elements_5729 : (tensor, tensor<2xi64>) -> tensor + %dim_5730 = tensor.dim %16304, %c0 : tensor + %16305 = arith.index_cast %dim_5730 : index to i64 + %from_elements_5731 = tensor.from_elements %c1_i64, %16305, %c4096_i64 : tensor<3xi64> + %16306 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_5731, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5732 = tensor.dim %16306, %c1 : tensor<1x?x4096xi64> + %16307 = arith.index_cast %dim_5732 : index to i64 + %from_elements_5733 = tensor.from_elements %c1_i64, %16307, %c4096_i64, %c1_i64 : tensor<4xi64> + %16308 = stablehlo.dynamic_reshape %16306, %from_elements_5733 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %16309 = stablehlo.dynamic_broadcast_in_dim %16304, %from_elements_5731, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5734 = tensor.dim %16309, %c1 : tensor<1x?x4096xi64> + %16310 = arith.index_cast %dim_5734 : index to i64 + %from_elements_5735 = tensor.from_elements %c1_i64, %16310, %c4096_i64, %c1_i64 : tensor<4xi64> + %16311 = stablehlo.dynamic_reshape %16309, %from_elements_5735 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %16312 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_5731, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5736 = tensor.dim %16312, %c1 : tensor<1x?x4096xi64> + %16313 = arith.index_cast %dim_5736 : index to i64 + %from_elements_5737 = tensor.from_elements %c1_i64, %16313, %c4096_i64, %c1_i64 : tensor<4xi64> + %16314 = stablehlo.dynamic_reshape %16312, %from_elements_5737 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %16315 = stablehlo.concatenate %16308, %16311, %16314, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %16316 = "stablehlo.gather"(%16112, %16315) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %16317 = shape.shape_of %16316 : tensor<1x?x4096xf32> -> tensor<3xindex> + %16318 = shape.num_elements %16317 : tensor<3xindex> -> index + %16319 = stablehlo.compute_reshape_shape %16318, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %16320 = stablehlo.dynamic_reshape %16316, %16319 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %16321 = stablehlo.dot %16320, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %16322 = stablehlo.logistic %16321 : tensor + %16323 = shape.shape_of %16322 : tensor -> tensor<2xindex> + %16324 = shape.shape_of %16321 : tensor -> tensor<2xindex> + %16325 = shape.cstr_broadcastable %16323, %16324 : tensor<2xindex>, tensor<2xindex> + %16326 = shape.assuming 
%16325 -> (tensor) { + %19688 = shape.broadcast %16323, %16324 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %16322, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %16321, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %16327 = shape.shape_of %16326 : tensor -> tensor<2xindex> + %16328 = shape.cstr_broadcastable %16327, %16324 : tensor<2xindex>, tensor<2xindex> + %16329 = shape.assuming %16328 -> (tensor) { + %19688 = shape.broadcast %16327, %16324 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %16326, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %16321, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %16330 = stablehlo.dot %16329, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_5738 = tensor.dim %16302, %c0 : tensor + %16331 = arith.index_cast %dim_5738 : index to i64 + %from_elements_5739 = tensor.from_elements %16331, %c1_i64 : tensor<2xi64> + %16332 = stablehlo.dynamic_reshape %16302, %from_elements_5739 : (tensor, tensor<2xi64>) -> tensor + %dim_5740 = tensor.dim %16299, %c0 : tensor + %16333 = arith.index_cast %dim_5740 : index to i64 + %from_elements_5741 = tensor.from_elements %16333, %c1_i64 : tensor<2xi64> + %16334 = stablehlo.dynamic_reshape %16299, %from_elements_5741 : (tensor, tensor<2xi64>) -> tensor + %16335 = stablehlo.concatenate %16332, %16334, dim = 1 : (tensor, tensor) -> tensor + %16336 = "stablehlo.gather"(%16141, %16335) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %16337 = shape.shape_of %16330 : tensor -> tensor<2xindex> + %16338 = shape.shape_of %16336 : tensor -> tensor<2xindex> + %16339 = shape.cstr_broadcastable %16337, %16338 : tensor<2xindex>, tensor<2xindex> + %16340 = shape.assuming %16339 -> (tensor) { + %19688 = shape.broadcast %16337, %16338 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %16330, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %16336, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %16341 = shape.shape_of %16340 : tensor -> tensor<2xindex> + %16342 = stablehlo.dynamic_broadcast_in_dim %16340, %16341, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %16343 = stablehlo.dynamic_broadcast_in_dim %213, %16341, dims = [] : (tensor, tensor<2xindex>) -> tensor + %16344 = stablehlo.multiply %16342, %16343 : tensor + %dim_5742 = tensor.dim %16304, %c0 : tensor + %16345 = arith.index_cast %dim_5742 : index to i64 + %dim_5743 = tensor.dim %16340, %c0 : tensor + %16346 = arith.index_cast %dim_5743 : index to i64 + %16347 = arith.maxsi %16345, %16346 : i64 + %16348 = arith.index_cast %16347 : i64 to index + %from_elements_5744 = tensor.from_elements %16348, %c4096 : tensor<2xindex> + %16349 = stablehlo.dynamic_broadcast_in_dim %16304, %from_elements_5744, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_5745 = tensor.dim %16349, %c0 : tensor + %16350 = arith.index_cast %dim_5745 : 
index to i64 + %from_elements_5746 = tensor.from_elements %16350, %c4096_i64 : tensor<2xi64> + %16351 = stablehlo.real_dynamic_slice %16344, %c_22, %from_elements_5746, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_5747 = tensor.from_elements %16350, %c4096_i64, %c1_i64 : tensor<3xi64> + %16352 = stablehlo.dynamic_reshape %16349, %from_elements_5747 : (tensor, tensor<3xi64>) -> tensor + %16353 = stablehlo.dynamic_iota %from_elements_5747, dim = 1 : (tensor<3xi64>) -> tensor + %16354 = stablehlo.concatenate %16352, %16353, dim = 2 : (tensor, tensor) -> tensor + %16355 = "stablehlo.scatter"(%16292, %16354, %16351) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %16356 = stablehlo.slice %16101 [4:5, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %16357 = stablehlo.reshape %16356 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %16358 = stablehlo.custom_call @byteir.non_zero(%16357) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_5748 = tensor.dim %16358, %c0 : tensor + %16359 = arith.index_cast %dim_5748 : index to i64 + %from_elements_5749 = tensor.from_elements %16359, %c1_i64 : tensor<2xi64> + %16360 = stablehlo.real_dynamic_slice %16358, %c_22, %from_elements_5749, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5750 = tensor.dim %16360, %c0 : tensor + %16361 = arith.index_cast %dim_5750 : index to i64 + %from_elements_5751 = tensor.from_elements %16361 : tensor<1xi64> + %16362 = stablehlo.dynamic_reshape %16360, %from_elements_5751 : (tensor, tensor<1xi64>) -> tensor + %from_elements_5752 = tensor.from_elements %16359, %c2_i64 : tensor<2xi64> + %16363 = stablehlo.real_dynamic_slice %16358, %c_24, %from_elements_5752, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5753 = tensor.dim %16363, %c0 : tensor + %16364 = arith.index_cast %dim_5753 : index to i64 + %from_elements_5754 = tensor.from_elements %16364 : tensor<1xi64> + %16365 = stablehlo.dynamic_reshape %16363, %from_elements_5754 : (tensor, tensor<1xi64>) -> tensor + %dim_5755 = tensor.dim %16365, %c0 : tensor + %16366 = arith.index_cast %dim_5755 : index to i64 + %from_elements_5756 = tensor.from_elements %16366, %c1_i64 : tensor<2xi64> + %16367 = stablehlo.dynamic_reshape %16365, %from_elements_5756 : (tensor, tensor<2xi64>) -> tensor + %dim_5757 = tensor.dim %16367, %c0 : tensor + %16368 = arith.index_cast %dim_5757 : index to i64 + %from_elements_5758 = tensor.from_elements %c1_i64, %16368, %c4096_i64 : tensor<3xi64> + %16369 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_5758, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5759 = tensor.dim %16369, %c1 : tensor<1x?x4096xi64> + %16370 = arith.index_cast %dim_5759 : index to i64 + %from_elements_5760 = tensor.from_elements %c1_i64, %16370, %c4096_i64, %c1_i64 : tensor<4xi64> + %16371 = stablehlo.dynamic_reshape %16369, %from_elements_5760 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %16372 = stablehlo.dynamic_broadcast_in_dim %16367, %from_elements_5758, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5761 = tensor.dim %16372, %c1 : tensor<1x?x4096xi64> + %16373 = arith.index_cast %dim_5761 : index to i64 + 
%from_elements_5762 = tensor.from_elements %c1_i64, %16373, %c4096_i64, %c1_i64 : tensor<4xi64> + %16374 = stablehlo.dynamic_reshape %16372, %from_elements_5762 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %16375 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_5758, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5763 = tensor.dim %16375, %c1 : tensor<1x?x4096xi64> + %16376 = arith.index_cast %dim_5763 : index to i64 + %from_elements_5764 = tensor.from_elements %c1_i64, %16376, %c4096_i64, %c1_i64 : tensor<4xi64> + %16377 = stablehlo.dynamic_reshape %16375, %from_elements_5764 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %16378 = stablehlo.concatenate %16371, %16374, %16377, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %16379 = "stablehlo.gather"(%16112, %16378) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %16380 = shape.shape_of %16379 : tensor<1x?x4096xf32> -> tensor<3xindex> + %16381 = shape.num_elements %16380 : tensor<3xindex> -> index + %16382 = stablehlo.compute_reshape_shape %16381, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %16383 = stablehlo.dynamic_reshape %16379, %16382 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %16384 = stablehlo.dot %16383, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %16385 = stablehlo.logistic %16384 : tensor + %16386 = shape.shape_of %16385 : tensor -> tensor<2xindex> + %16387 = shape.shape_of %16384 : tensor -> tensor<2xindex> + %16388 = shape.cstr_broadcastable %16386, %16387 : tensor<2xindex>, tensor<2xindex> + %16389 = shape.assuming %16388 -> (tensor) { + %19688 = shape.broadcast %16386, %16387 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %16385, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %16384, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %16390 = shape.shape_of %16389 : tensor -> tensor<2xindex> + %16391 = shape.cstr_broadcastable %16390, %16387 : tensor<2xindex>, tensor<2xindex> + %16392 = shape.assuming %16391 -> (tensor) { + %19688 = shape.broadcast %16390, %16387 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %16389, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %16384, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %16393 = stablehlo.dot %16392, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_5765 = tensor.dim %16365, %c0 : tensor + %16394 = arith.index_cast %dim_5765 : index to i64 + %from_elements_5766 = tensor.from_elements %16394, %c1_i64 : tensor<2xi64> + %16395 = stablehlo.dynamic_reshape %16365, %from_elements_5766 : (tensor, tensor<2xi64>) -> tensor + %dim_5767 = tensor.dim %16362, %c0 : tensor + %16396 = arith.index_cast %dim_5767 : index to i64 + %from_elements_5768 = tensor.from_elements %16396, %c1_i64 : tensor<2xi64> + %16397 = stablehlo.dynamic_reshape %16362, %from_elements_5768 : (tensor, tensor<2xi64>) -> tensor + %16398 = stablehlo.concatenate %16395, %16397, dim = 1 : (tensor, 
tensor) -> tensor + %16399 = "stablehlo.gather"(%16141, %16398) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %16400 = shape.shape_of %16393 : tensor -> tensor<2xindex> + %16401 = shape.shape_of %16399 : tensor -> tensor<2xindex> + %16402 = shape.cstr_broadcastable %16400, %16401 : tensor<2xindex>, tensor<2xindex> + %16403 = shape.assuming %16402 -> (tensor) { + %19688 = shape.broadcast %16400, %16401 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %16393, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %16399, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %16404 = shape.shape_of %16403 : tensor -> tensor<2xindex> + %16405 = stablehlo.dynamic_broadcast_in_dim %16403, %16404, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %16406 = stablehlo.dynamic_broadcast_in_dim %213, %16404, dims = [] : (tensor, tensor<2xindex>) -> tensor + %16407 = stablehlo.multiply %16405, %16406 : tensor + %dim_5769 = tensor.dim %16367, %c0 : tensor + %16408 = arith.index_cast %dim_5769 : index to i64 + %dim_5770 = tensor.dim %16403, %c0 : tensor + %16409 = arith.index_cast %dim_5770 : index to i64 + %16410 = arith.maxsi %16408, %16409 : i64 + %16411 = arith.index_cast %16410 : i64 to index + %from_elements_5771 = tensor.from_elements %16411, %c4096 : tensor<2xindex> + %16412 = stablehlo.dynamic_broadcast_in_dim %16367, %from_elements_5771, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_5772 = tensor.dim %16412, %c0 : tensor + %16413 = arith.index_cast %dim_5772 : index to i64 + %from_elements_5773 = tensor.from_elements %16413, %c4096_i64 : tensor<2xi64> + %16414 = stablehlo.real_dynamic_slice %16407, %c_22, %from_elements_5773, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_5774 = tensor.from_elements %16413, %c4096_i64, %c1_i64 : tensor<3xi64> + %16415 = stablehlo.dynamic_reshape %16412, %from_elements_5774 : (tensor, tensor<3xi64>) -> tensor + %16416 = stablehlo.dynamic_iota %from_elements_5774, dim = 1 : (tensor<3xi64>) -> tensor + %16417 = stablehlo.concatenate %16415, %16416, dim = 2 : (tensor, tensor) -> tensor + %16418 = "stablehlo.scatter"(%16355, %16417, %16414) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %16419 = stablehlo.slice %16101 [5:6, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %16420 = stablehlo.reshape %16419 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %16421 = stablehlo.custom_call @byteir.non_zero(%16420) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_5775 = tensor.dim %16421, %c0 : tensor + %16422 = arith.index_cast %dim_5775 : index to i64 + %from_elements_5776 = tensor.from_elements %16422, %c1_i64 : tensor<2xi64> + %16423 = stablehlo.real_dynamic_slice %16421, %c_22, %from_elements_5776, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5777 = tensor.dim %16423, %c0 : tensor + %16424 = arith.index_cast %dim_5777 : index to i64 + %from_elements_5778 = tensor.from_elements %16424 : tensor<1xi64> + %16425 = 
stablehlo.dynamic_reshape %16423, %from_elements_5778 : (tensor, tensor<1xi64>) -> tensor + %from_elements_5779 = tensor.from_elements %16422, %c2_i64 : tensor<2xi64> + %16426 = stablehlo.real_dynamic_slice %16421, %c_24, %from_elements_5779, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5780 = tensor.dim %16426, %c0 : tensor + %16427 = arith.index_cast %dim_5780 : index to i64 + %from_elements_5781 = tensor.from_elements %16427 : tensor<1xi64> + %16428 = stablehlo.dynamic_reshape %16426, %from_elements_5781 : (tensor, tensor<1xi64>) -> tensor + %dim_5782 = tensor.dim %16428, %c0 : tensor + %16429 = arith.index_cast %dim_5782 : index to i64 + %from_elements_5783 = tensor.from_elements %16429, %c1_i64 : tensor<2xi64> + %16430 = stablehlo.dynamic_reshape %16428, %from_elements_5783 : (tensor, tensor<2xi64>) -> tensor + %dim_5784 = tensor.dim %16430, %c0 : tensor + %16431 = arith.index_cast %dim_5784 : index to i64 + %from_elements_5785 = tensor.from_elements %c1_i64, %16431, %c4096_i64 : tensor<3xi64> + %16432 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_5785, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5786 = tensor.dim %16432, %c1 : tensor<1x?x4096xi64> + %16433 = arith.index_cast %dim_5786 : index to i64 + %from_elements_5787 = tensor.from_elements %c1_i64, %16433, %c4096_i64, %c1_i64 : tensor<4xi64> + %16434 = stablehlo.dynamic_reshape %16432, %from_elements_5787 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %16435 = stablehlo.dynamic_broadcast_in_dim %16430, %from_elements_5785, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5788 = tensor.dim %16435, %c1 : tensor<1x?x4096xi64> + %16436 = arith.index_cast %dim_5788 : index to i64 + %from_elements_5789 = tensor.from_elements %c1_i64, %16436, %c4096_i64, %c1_i64 : tensor<4xi64> + %16437 = stablehlo.dynamic_reshape %16435, %from_elements_5789 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %16438 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_5785, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5790 = tensor.dim %16438, %c1 : tensor<1x?x4096xi64> + %16439 = arith.index_cast %dim_5790 : index to i64 + %from_elements_5791 = tensor.from_elements %c1_i64, %16439, %c4096_i64, %c1_i64 : tensor<4xi64> + %16440 = stablehlo.dynamic_reshape %16438, %from_elements_5791 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %16441 = stablehlo.concatenate %16434, %16437, %16440, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %16442 = "stablehlo.gather"(%16112, %16441) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %16443 = shape.shape_of %16442 : tensor<1x?x4096xf32> -> tensor<3xindex> + %16444 = shape.num_elements %16443 : tensor<3xindex> -> index + %16445 = stablehlo.compute_reshape_shape %16444, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %16446 = stablehlo.dynamic_reshape %16442, %16445 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %16447 = stablehlo.dot %16446, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %16448 = stablehlo.logistic %16447 : tensor + %16449 = shape.shape_of %16448 : tensor -> tensor<2xindex> + %16450 = shape.shape_of %16447 : tensor -> tensor<2xindex> + %16451 = shape.cstr_broadcastable %16449, %16450 : tensor<2xindex>, 
tensor<2xindex> + %16452 = shape.assuming %16451 -> (tensor) { + %19688 = shape.broadcast %16449, %16450 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %16448, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %16447, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %16453 = shape.shape_of %16452 : tensor -> tensor<2xindex> + %16454 = shape.cstr_broadcastable %16453, %16450 : tensor<2xindex>, tensor<2xindex> + %16455 = shape.assuming %16454 -> (tensor) { + %19688 = shape.broadcast %16453, %16450 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %16452, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %16447, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %16456 = stablehlo.dot %16455, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_5792 = tensor.dim %16428, %c0 : tensor + %16457 = arith.index_cast %dim_5792 : index to i64 + %from_elements_5793 = tensor.from_elements %16457, %c1_i64 : tensor<2xi64> + %16458 = stablehlo.dynamic_reshape %16428, %from_elements_5793 : (tensor, tensor<2xi64>) -> tensor + %dim_5794 = tensor.dim %16425, %c0 : tensor + %16459 = arith.index_cast %dim_5794 : index to i64 + %from_elements_5795 = tensor.from_elements %16459, %c1_i64 : tensor<2xi64> + %16460 = stablehlo.dynamic_reshape %16425, %from_elements_5795 : (tensor, tensor<2xi64>) -> tensor + %16461 = stablehlo.concatenate %16458, %16460, dim = 1 : (tensor, tensor) -> tensor + %16462 = "stablehlo.gather"(%16141, %16461) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %16463 = shape.shape_of %16456 : tensor -> tensor<2xindex> + %16464 = shape.shape_of %16462 : tensor -> tensor<2xindex> + %16465 = shape.cstr_broadcastable %16463, %16464 : tensor<2xindex>, tensor<2xindex> + %16466 = shape.assuming %16465 -> (tensor) { + %19688 = shape.broadcast %16463, %16464 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %16456, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %16462, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %16467 = shape.shape_of %16466 : tensor -> tensor<2xindex> + %16468 = stablehlo.dynamic_broadcast_in_dim %16466, %16467, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %16469 = stablehlo.dynamic_broadcast_in_dim %213, %16467, dims = [] : (tensor, tensor<2xindex>) -> tensor + %16470 = stablehlo.multiply %16468, %16469 : tensor + %dim_5796 = tensor.dim %16430, %c0 : tensor + %16471 = arith.index_cast %dim_5796 : index to i64 + %dim_5797 = tensor.dim %16466, %c0 : tensor + %16472 = arith.index_cast %dim_5797 : index to i64 + %16473 = arith.maxsi %16471, %16472 : i64 + %16474 = arith.index_cast %16473 : i64 to index + %from_elements_5798 = tensor.from_elements %16474, %c4096 : tensor<2xindex> + %16475 = stablehlo.dynamic_broadcast_in_dim %16430, %from_elements_5798, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_5799 = tensor.dim %16475, %c0 : tensor 
+ %16476 = arith.index_cast %dim_5799 : index to i64 + %from_elements_5800 = tensor.from_elements %16476, %c4096_i64 : tensor<2xi64> + %16477 = stablehlo.real_dynamic_slice %16470, %c_22, %from_elements_5800, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_5801 = tensor.from_elements %16476, %c4096_i64, %c1_i64 : tensor<3xi64> + %16478 = stablehlo.dynamic_reshape %16475, %from_elements_5801 : (tensor, tensor<3xi64>) -> tensor + %16479 = stablehlo.dynamic_iota %from_elements_5801, dim = 1 : (tensor<3xi64>) -> tensor + %16480 = stablehlo.concatenate %16478, %16479, dim = 2 : (tensor, tensor) -> tensor + %16481 = "stablehlo.scatter"(%16418, %16480, %16477) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %16482 = stablehlo.slice %16101 [6:7, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %16483 = stablehlo.reshape %16482 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %16484 = stablehlo.custom_call @byteir.non_zero(%16483) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_5802 = tensor.dim %16484, %c0 : tensor + %16485 = arith.index_cast %dim_5802 : index to i64 + %from_elements_5803 = tensor.from_elements %16485, %c1_i64 : tensor<2xi64> + %16486 = stablehlo.real_dynamic_slice %16484, %c_22, %from_elements_5803, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5804 = tensor.dim %16486, %c0 : tensor + %16487 = arith.index_cast %dim_5804 : index to i64 + %from_elements_5805 = tensor.from_elements %16487 : tensor<1xi64> + %16488 = stablehlo.dynamic_reshape %16486, %from_elements_5805 : (tensor, tensor<1xi64>) -> tensor + %from_elements_5806 = tensor.from_elements %16485, %c2_i64 : tensor<2xi64> + %16489 = stablehlo.real_dynamic_slice %16484, %c_24, %from_elements_5806, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5807 = tensor.dim %16489, %c0 : tensor + %16490 = arith.index_cast %dim_5807 : index to i64 + %from_elements_5808 = tensor.from_elements %16490 : tensor<1xi64> + %16491 = stablehlo.dynamic_reshape %16489, %from_elements_5808 : (tensor, tensor<1xi64>) -> tensor + %dim_5809 = tensor.dim %16491, %c0 : tensor + %16492 = arith.index_cast %dim_5809 : index to i64 + %from_elements_5810 = tensor.from_elements %16492, %c1_i64 : tensor<2xi64> + %16493 = stablehlo.dynamic_reshape %16491, %from_elements_5810 : (tensor, tensor<2xi64>) -> tensor + %dim_5811 = tensor.dim %16493, %c0 : tensor + %16494 = arith.index_cast %dim_5811 : index to i64 + %from_elements_5812 = tensor.from_elements %c1_i64, %16494, %c4096_i64 : tensor<3xi64> + %16495 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_5812, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5813 = tensor.dim %16495, %c1 : tensor<1x?x4096xi64> + %16496 = arith.index_cast %dim_5813 : index to i64 + %from_elements_5814 = tensor.from_elements %c1_i64, %16496, %c4096_i64, %c1_i64 : tensor<4xi64> + %16497 = stablehlo.dynamic_reshape %16495, %from_elements_5814 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %16498 = stablehlo.dynamic_broadcast_in_dim %16493, %from_elements_5812, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5815 = tensor.dim %16498, %c1 : tensor<1x?x4096xi64> + %16499 = arith.index_cast 
%dim_5815 : index to i64 + %from_elements_5816 = tensor.from_elements %c1_i64, %16499, %c4096_i64, %c1_i64 : tensor<4xi64> + %16500 = stablehlo.dynamic_reshape %16498, %from_elements_5816 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %16501 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_5812, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5817 = tensor.dim %16501, %c1 : tensor<1x?x4096xi64> + %16502 = arith.index_cast %dim_5817 : index to i64 + %from_elements_5818 = tensor.from_elements %c1_i64, %16502, %c4096_i64, %c1_i64 : tensor<4xi64> + %16503 = stablehlo.dynamic_reshape %16501, %from_elements_5818 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %16504 = stablehlo.concatenate %16497, %16500, %16503, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %16505 = "stablehlo.gather"(%16112, %16504) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %16506 = shape.shape_of %16505 : tensor<1x?x4096xf32> -> tensor<3xindex> + %16507 = shape.num_elements %16506 : tensor<3xindex> -> index + %16508 = stablehlo.compute_reshape_shape %16507, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %16509 = stablehlo.dynamic_reshape %16505, %16508 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %16510 = stablehlo.dot %16509, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %16511 = stablehlo.logistic %16510 : tensor + %16512 = shape.shape_of %16511 : tensor -> tensor<2xindex> + %16513 = shape.shape_of %16510 : tensor -> tensor<2xindex> + %16514 = shape.cstr_broadcastable %16512, %16513 : tensor<2xindex>, tensor<2xindex> + %16515 = shape.assuming %16514 -> (tensor) { + %19688 = shape.broadcast %16512, %16513 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %16511, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %16510, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %16516 = shape.shape_of %16515 : tensor -> tensor<2xindex> + %16517 = shape.cstr_broadcastable %16516, %16513 : tensor<2xindex>, tensor<2xindex> + %16518 = shape.assuming %16517 -> (tensor) { + %19688 = shape.broadcast %16516, %16513 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %16515, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %16510, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %16519 = stablehlo.dot %16518, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_5819 = tensor.dim %16491, %c0 : tensor + %16520 = arith.index_cast %dim_5819 : index to i64 + %from_elements_5820 = tensor.from_elements %16520, %c1_i64 : tensor<2xi64> + %16521 = stablehlo.dynamic_reshape %16491, %from_elements_5820 : (tensor, tensor<2xi64>) -> tensor + %dim_5821 = tensor.dim %16488, %c0 : tensor + %16522 = arith.index_cast %dim_5821 : index to i64 + %from_elements_5822 = tensor.from_elements %16522, %c1_i64 : tensor<2xi64> + %16523 = stablehlo.dynamic_reshape %16488, %from_elements_5822 : (tensor, tensor<2xi64>) -> tensor + %16524 = stablehlo.concatenate %16521, 
%16523, dim = 1 : (tensor, tensor) -> tensor + %16525 = "stablehlo.gather"(%16141, %16524) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %16526 = shape.shape_of %16519 : tensor -> tensor<2xindex> + %16527 = shape.shape_of %16525 : tensor -> tensor<2xindex> + %16528 = shape.cstr_broadcastable %16526, %16527 : tensor<2xindex>, tensor<2xindex> + %16529 = shape.assuming %16528 -> (tensor) { + %19688 = shape.broadcast %16526, %16527 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %16519, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %16525, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %16530 = shape.shape_of %16529 : tensor -> tensor<2xindex> + %16531 = stablehlo.dynamic_broadcast_in_dim %16529, %16530, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %16532 = stablehlo.dynamic_broadcast_in_dim %213, %16530, dims = [] : (tensor, tensor<2xindex>) -> tensor + %16533 = stablehlo.multiply %16531, %16532 : tensor + %dim_5823 = tensor.dim %16493, %c0 : tensor + %16534 = arith.index_cast %dim_5823 : index to i64 + %dim_5824 = tensor.dim %16529, %c0 : tensor + %16535 = arith.index_cast %dim_5824 : index to i64 + %16536 = arith.maxsi %16534, %16535 : i64 + %16537 = arith.index_cast %16536 : i64 to index + %from_elements_5825 = tensor.from_elements %16537, %c4096 : tensor<2xindex> + %16538 = stablehlo.dynamic_broadcast_in_dim %16493, %from_elements_5825, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_5826 = tensor.dim %16538, %c0 : tensor + %16539 = arith.index_cast %dim_5826 : index to i64 + %from_elements_5827 = tensor.from_elements %16539, %c4096_i64 : tensor<2xi64> + %16540 = stablehlo.real_dynamic_slice %16533, %c_22, %from_elements_5827, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_5828 = tensor.from_elements %16539, %c4096_i64, %c1_i64 : tensor<3xi64> + %16541 = stablehlo.dynamic_reshape %16538, %from_elements_5828 : (tensor, tensor<3xi64>) -> tensor + %16542 = stablehlo.dynamic_iota %from_elements_5828, dim = 1 : (tensor<3xi64>) -> tensor + %16543 = stablehlo.concatenate %16541, %16542, dim = 2 : (tensor, tensor) -> tensor + %16544 = "stablehlo.scatter"(%16481, %16543, %16540) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %16545 = stablehlo.slice %16101 [7:8, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %16546 = stablehlo.reshape %16545 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %16547 = stablehlo.custom_call @byteir.non_zero(%16546) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_5829 = tensor.dim %16547, %c0 : tensor + %16548 = arith.index_cast %dim_5829 : index to i64 + %from_elements_5830 = tensor.from_elements %16548, %c1_i64 : tensor<2xi64> + %16549 = stablehlo.real_dynamic_slice %16547, %c_22, %from_elements_5830, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5831 = tensor.dim %16549, %c0 : tensor + %16550 = arith.index_cast %dim_5831 : index to i64 + %from_elements_5832 = tensor.from_elements %16550 : tensor<1xi64> + 
%16551 = stablehlo.dynamic_reshape %16549, %from_elements_5832 : (tensor, tensor<1xi64>) -> tensor + %from_elements_5833 = tensor.from_elements %16548, %c2_i64 : tensor<2xi64> + %16552 = stablehlo.real_dynamic_slice %16547, %c_24, %from_elements_5833, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5834 = tensor.dim %16552, %c0 : tensor + %16553 = arith.index_cast %dim_5834 : index to i64 + %from_elements_5835 = tensor.from_elements %16553 : tensor<1xi64> + %16554 = stablehlo.dynamic_reshape %16552, %from_elements_5835 : (tensor, tensor<1xi64>) -> tensor + %dim_5836 = tensor.dim %16554, %c0 : tensor + %16555 = arith.index_cast %dim_5836 : index to i64 + %from_elements_5837 = tensor.from_elements %16555, %c1_i64 : tensor<2xi64> + %16556 = stablehlo.dynamic_reshape %16554, %from_elements_5837 : (tensor, tensor<2xi64>) -> tensor + %dim_5838 = tensor.dim %16556, %c0 : tensor + %16557 = arith.index_cast %dim_5838 : index to i64 + %from_elements_5839 = tensor.from_elements %c1_i64, %16557, %c4096_i64 : tensor<3xi64> + %16558 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_5839, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5840 = tensor.dim %16558, %c1 : tensor<1x?x4096xi64> + %16559 = arith.index_cast %dim_5840 : index to i64 + %from_elements_5841 = tensor.from_elements %c1_i64, %16559, %c4096_i64, %c1_i64 : tensor<4xi64> + %16560 = stablehlo.dynamic_reshape %16558, %from_elements_5841 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %16561 = stablehlo.dynamic_broadcast_in_dim %16556, %from_elements_5839, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5842 = tensor.dim %16561, %c1 : tensor<1x?x4096xi64> + %16562 = arith.index_cast %dim_5842 : index to i64 + %from_elements_5843 = tensor.from_elements %c1_i64, %16562, %c4096_i64, %c1_i64 : tensor<4xi64> + %16563 = stablehlo.dynamic_reshape %16561, %from_elements_5843 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %16564 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_5839, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5844 = tensor.dim %16564, %c1 : tensor<1x?x4096xi64> + %16565 = arith.index_cast %dim_5844 : index to i64 + %from_elements_5845 = tensor.from_elements %c1_i64, %16565, %c4096_i64, %c1_i64 : tensor<4xi64> + %16566 = stablehlo.dynamic_reshape %16564, %from_elements_5845 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %16567 = stablehlo.concatenate %16560, %16563, %16566, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %16568 = "stablehlo.gather"(%16112, %16567) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %16569 = shape.shape_of %16568 : tensor<1x?x4096xf32> -> tensor<3xindex> + %16570 = shape.num_elements %16569 : tensor<3xindex> -> index + %16571 = stablehlo.compute_reshape_shape %16570, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %16572 = stablehlo.dynamic_reshape %16568, %16571 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %16573 = stablehlo.dot %16572, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %16574 = stablehlo.logistic %16573 : tensor + %16575 = shape.shape_of %16574 : tensor -> tensor<2xindex> + %16576 = shape.shape_of %16573 : tensor -> tensor<2xindex> + %16577 = shape.cstr_broadcastable %16575, %16576 : tensor<2xindex>, 
tensor<2xindex> + %16578 = shape.assuming %16577 -> (tensor) { + %19688 = shape.broadcast %16575, %16576 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %16574, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %16573, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %16579 = shape.shape_of %16578 : tensor -> tensor<2xindex> + %16580 = shape.cstr_broadcastable %16579, %16576 : tensor<2xindex>, tensor<2xindex> + %16581 = shape.assuming %16580 -> (tensor) { + %19688 = shape.broadcast %16579, %16576 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %16578, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %16573, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %16582 = stablehlo.dot %16581, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_5846 = tensor.dim %16554, %c0 : tensor + %16583 = arith.index_cast %dim_5846 : index to i64 + %from_elements_5847 = tensor.from_elements %16583, %c1_i64 : tensor<2xi64> + %16584 = stablehlo.dynamic_reshape %16554, %from_elements_5847 : (tensor, tensor<2xi64>) -> tensor + %dim_5848 = tensor.dim %16551, %c0 : tensor + %16585 = arith.index_cast %dim_5848 : index to i64 + %from_elements_5849 = tensor.from_elements %16585, %c1_i64 : tensor<2xi64> + %16586 = stablehlo.dynamic_reshape %16551, %from_elements_5849 : (tensor, tensor<2xi64>) -> tensor + %16587 = stablehlo.concatenate %16584, %16586, dim = 1 : (tensor, tensor) -> tensor + %16588 = "stablehlo.gather"(%16141, %16587) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %16589 = shape.shape_of %16582 : tensor -> tensor<2xindex> + %16590 = shape.shape_of %16588 : tensor -> tensor<2xindex> + %16591 = shape.cstr_broadcastable %16589, %16590 : tensor<2xindex>, tensor<2xindex> + %16592 = shape.assuming %16591 -> (tensor) { + %19688 = shape.broadcast %16589, %16590 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %16582, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %16588, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %16593 = shape.shape_of %16592 : tensor -> tensor<2xindex> + %16594 = stablehlo.dynamic_broadcast_in_dim %16592, %16593, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %16595 = stablehlo.dynamic_broadcast_in_dim %213, %16593, dims = [] : (tensor, tensor<2xindex>) -> tensor + %16596 = stablehlo.multiply %16594, %16595 : tensor + %dim_5850 = tensor.dim %16556, %c0 : tensor + %16597 = arith.index_cast %dim_5850 : index to i64 + %dim_5851 = tensor.dim %16592, %c0 : tensor + %16598 = arith.index_cast %dim_5851 : index to i64 + %16599 = arith.maxsi %16597, %16598 : i64 + %16600 = arith.index_cast %16599 : i64 to index + %from_elements_5852 = tensor.from_elements %16600, %c4096 : tensor<2xindex> + %16601 = stablehlo.dynamic_broadcast_in_dim %16556, %from_elements_5852, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_5853 = tensor.dim %16601, %c0 : tensor 
+ %16602 = arith.index_cast %dim_5853 : index to i64 + %from_elements_5854 = tensor.from_elements %16602, %c4096_i64 : tensor<2xi64> + %16603 = stablehlo.real_dynamic_slice %16596, %c_22, %from_elements_5854, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_5855 = tensor.from_elements %16602, %c4096_i64, %c1_i64 : tensor<3xi64> + %16604 = stablehlo.dynamic_reshape %16601, %from_elements_5855 : (tensor, tensor<3xi64>) -> tensor + %16605 = stablehlo.dynamic_iota %from_elements_5855, dim = 1 : (tensor<3xi64>) -> tensor + %16606 = stablehlo.concatenate %16604, %16605, dim = 2 : (tensor, tensor) -> tensor + %16607 = "stablehlo.scatter"(%16544, %16606, %16603) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %16608 = stablehlo.reshape %16607 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %16609 = stablehlo.add %16074, %16608 : tensor<3x1x4096xf32> + %16610 = stablehlo.broadcast_in_dim %16609, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %16611 = stablehlo.power %16610, %15 : tensor<3x1x4096xf32> + %16612 = stablehlo.reduce(%16611 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %16613 = stablehlo.reshape %16612 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %16614 = stablehlo.broadcast_in_dim %16613, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %16615 = stablehlo.divide %16614, %21 : tensor<3x1x1xf32> + %16616 = stablehlo.broadcast_in_dim %16615, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %16617 = stablehlo.add %16616, %25 : tensor<3x1x1xf32> + %16618 = stablehlo.rsqrt %16617 : tensor<3x1x1xf32> + %16619 = stablehlo.broadcast_in_dim %16618, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %16620 = stablehlo.multiply %16610, %16619 : tensor<3x1x4096xf32> + %16621 = stablehlo.broadcast_in_dim %16620, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %16622 = stablehlo.multiply %16621, %31 : tensor<3x1x4096xf32> + %16623 = stablehlo.reshape %16622 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %16624 = stablehlo.dot %16623, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %16625 = stablehlo.reshape %16624 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %16626 = stablehlo.dot %16623, %37 : (tensor<3x4096xf32>, tensor<4096x1024xf32>) -> tensor<3x1024xf32> + %16627 = stablehlo.reshape %16626 : (tensor<3x1024xf32>) -> tensor<3x1x1024xf32> + %16628 = stablehlo.reshape %16625 : (tensor<3x1x4096xf32>) -> tensor<3x1x32x128xf32> + %16629 = stablehlo.transpose %16628, dims = [0, 2, 1, 3] : (tensor<3x1x32x128xf32>) -> tensor<3x32x1x128xf32> + %16630 = stablehlo.reshape %16627 : (tensor<3x1x1024xf32>) -> tensor<3x1x8x128xf32> + %16631 = stablehlo.transpose %16630, dims = [0, 2, 1, 3] : (tensor<3x1x8x128xf32>) -> tensor<3x8x1x128xf32> + %16632 = stablehlo.slice %arg54 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %16633 = stablehlo.slice %arg55 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %16634 = "stablehlo.gather"(%16632, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %16635 = 
stablehlo.reshape %16634 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %16636 = "stablehlo.gather"(%16633, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %16637 = stablehlo.reshape %16636 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %16638 = stablehlo.broadcast_in_dim %16629, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %16639 = stablehlo.broadcast_in_dim %16635, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %16640 = stablehlo.multiply %16638, %16639 : tensor<3x32x1x128xf32> + %16641 = stablehlo.slice %16629 [0:3, 0:32, 0:1, 0:64] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %16642 = stablehlo.slice %16629 [0:3, 0:32, 0:1, 64:128] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %16643 = stablehlo.negate %16642 : tensor<3x32x1x64xf32> + %16644 = stablehlo.concatenate %16643, %16641, dim = 3 : (tensor<3x32x1x64xf32>, tensor<3x32x1x64xf32>) -> tensor<3x32x1x128xf32> + %16645 = stablehlo.broadcast_in_dim %16644, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %16646 = stablehlo.broadcast_in_dim %16637, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %16647 = stablehlo.multiply %16645, %16646 : tensor<3x32x1x128xf32> + %16648 = stablehlo.add %16640, %16647 : tensor<3x32x1x128xf32> + %16649 = stablehlo.broadcast_in_dim %16631, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %16650 = stablehlo.broadcast_in_dim %16635, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %16651 = stablehlo.multiply %16649, %16650 : tensor<3x8x1x128xf32> + %16652 = stablehlo.slice %16631 [0:3, 0:8, 0:1, 0:64] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %16653 = stablehlo.slice %16631 [0:3, 0:8, 0:1, 64:128] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %16654 = stablehlo.negate %16653 : tensor<3x8x1x64xf32> + %16655 = stablehlo.concatenate %16654, %16652, dim = 3 : (tensor<3x8x1x64xf32>, tensor<3x8x1x64xf32>) -> tensor<3x8x1x128xf32> + %16656 = stablehlo.broadcast_in_dim %16655, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %16657 = stablehlo.broadcast_in_dim %16637, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %16658 = stablehlo.multiply %16656, %16657 : tensor<3x8x1x128xf32> + %16659 = stablehlo.add %16651, %16658 : tensor<3x8x1x128xf32> + %16660 = stablehlo.concatenate %arg119, %16659, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %16661 = stablehlo.concatenate %arg120, %16631, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %16662 = stablehlo.reshape %16660 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %16663 = stablehlo.broadcast_in_dim %16662, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %16664 = stablehlo.reshape %16663 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %16665 = stablehlo.reshape %16661 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %16666 = stablehlo.broadcast_in_dim %16665, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %16667 = stablehlo.reshape %16666 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %16668 = stablehlo.transpose %16664, dims = [0, 1, 3, 2] : (tensor<3x32x8x128xf32>) -> tensor<3x32x128x8xf32> + %16669 = stablehlo.reshape %16648 : 
(tensor<3x32x1x128xf32>) -> tensor<96x1x128xf32> + %16670 = stablehlo.reshape %16668 : (tensor<3x32x128x8xf32>) -> tensor<96x128x8xf32> + %16671 = stablehlo.broadcast_in_dim %16670, dims = [0, 1, 2] : (tensor<96x128x8xf32>) -> tensor<96x128x8xf32> + %16672 = stablehlo.dot_general %16669, %16671, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x128xf32>, tensor<96x128x8xf32>) -> tensor<96x1x8xf32> + %16673 = stablehlo.reshape %16672 : (tensor<96x1x8xf32>) -> tensor<3x32x1x8xf32> + %16674 = stablehlo.broadcast_in_dim %16673, dims = [0, 1, 2, 3] : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %16675 = stablehlo.divide %16674, %89 : tensor<3x32x1x8xf32> + %16676 = stablehlo.custom_call @byteir.softmax(%16675) {byteir_attrs = {axis = 3 : i64}} : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %16677 = stablehlo.reshape %16676 : (tensor<3x32x1x8xf32>) -> tensor<96x1x8xf32> + %16678 = stablehlo.reshape %16667 : (tensor<3x32x8x128xf32>) -> tensor<96x8x128xf32> + %16679 = stablehlo.broadcast_in_dim %16678, dims = [0, 1, 2] : (tensor<96x8x128xf32>) -> tensor<96x8x128xf32> + %16680 = stablehlo.dot_general %16677, %16679, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x8xf32>, tensor<96x8x128xf32>) -> tensor<96x1x128xf32> + %16681 = stablehlo.reshape %16680 : (tensor<96x1x128xf32>) -> tensor<3x32x1x128xf32> + %16682 = stablehlo.transpose %16681, dims = [0, 2, 1, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x1x32x128xf32> + %16683 = stablehlo.reshape %16682 : (tensor<3x1x32x128xf32>) -> tensor<3x1x4096xf32> + %16684 = stablehlo.reshape %16683 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %16685 = stablehlo.dot %16684, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %16686 = stablehlo.reshape %16685 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %16687 = stablehlo.add %16609, %16686 : tensor<3x1x4096xf32> + %16688 = stablehlo.broadcast_in_dim %16687, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %16689 = stablehlo.power %16688, %15 : tensor<3x1x4096xf32> + %16690 = stablehlo.reduce(%16689 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %16691 = stablehlo.reshape %16690 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %16692 = stablehlo.broadcast_in_dim %16691, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %16693 = stablehlo.divide %16692, %21 : tensor<3x1x1xf32> + %16694 = stablehlo.broadcast_in_dim %16693, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %16695 = stablehlo.add %16694, %25 : tensor<3x1x1xf32> + %16696 = stablehlo.rsqrt %16695 : tensor<3x1x1xf32> + %16697 = stablehlo.broadcast_in_dim %16696, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %16698 = stablehlo.multiply %16688, %16697 : tensor<3x1x4096xf32> + %16699 = stablehlo.broadcast_in_dim %16698, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %16700 = stablehlo.multiply %16699, %31 : tensor<3x1x4096xf32> + %16701 = stablehlo.reshape %16700 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %16702 = stablehlo.dot %16701, %117 : (tensor<3x4096xf32>, tensor<4096x8xf32>) -> tensor<3x8xf32> + %16703 = stablehlo.custom_call @byteir.softmax(%16702) {byteir_attrs = {axis = 1 : i64}} : (tensor<3x8xf32>) -> tensor<3x8xf32> + %16704:2 = stablehlo.custom_call @byteir.top_k(%16703) {byteir_attrs = {axis = [1], k = 2 : i64, sorted = true}} : (tensor<3x8xf32>) -> (tensor<3x2xf32>, tensor<3x2xi64>) + %16705 = 
stablehlo.reduce(%16704#0 init: %cst_20) applies stablehlo.add across dimensions = [1] : (tensor<3x2xf32>, tensor) -> tensor<3xf32> + %16706 = stablehlo.reshape %16705 : (tensor<3xf32>) -> tensor<3x1xf32> + %16707 = stablehlo.broadcast_in_dim %16704#0, dims = [0, 1] : (tensor<3x2xf32>) -> tensor<3x2xf32> + %16708 = stablehlo.broadcast_in_dim %16706, dims = [0, 1] : (tensor<3x1xf32>) -> tensor<3x2xf32> + %16709 = stablehlo.divide %16707, %16708 : tensor<3x2xf32> + %16710 = stablehlo.reshape %16704#1 : (tensor<3x2xi64>) -> tensor<3x2x1xi64> + %16711 = stablehlo.broadcast_in_dim %16710, dims = [0, 1, 2] : (tensor<3x2x1xi64>) -> tensor<3x2x8xi64> + %16712 = stablehlo.compare EQ, %16711, %137, SIGNED : (tensor<3x2x8xi64>, tensor<3x2x8xi64>) -> tensor<3x2x8xi1> + %16713 = stablehlo.convert %16712 : (tensor<3x2x8xi1>) -> tensor<3x2x8xi64> + %16714 = stablehlo.transpose %16713, dims = [2, 1, 0] : (tensor<3x2x8xi64>) -> tensor<8x2x3xi64> + %16715 = stablehlo.slice %16714 [0:1, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %16716 = stablehlo.reshape %16715 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %16717 = stablehlo.custom_call @byteir.non_zero(%16716) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_5856 = tensor.dim %16717, %c0 : tensor + %16718 = arith.index_cast %dim_5856 : index to i64 + %from_elements_5857 = tensor.from_elements %16718, %c1_i64 : tensor<2xi64> + %16719 = stablehlo.real_dynamic_slice %16717, %c_22, %from_elements_5857, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5858 = tensor.dim %16719, %c0 : tensor + %16720 = arith.index_cast %dim_5858 : index to i64 + %from_elements_5859 = tensor.from_elements %16720 : tensor<1xi64> + %16721 = stablehlo.dynamic_reshape %16719, %from_elements_5859 : (tensor, tensor<1xi64>) -> tensor + %from_elements_5860 = tensor.from_elements %16718, %c2_i64 : tensor<2xi64> + %16722 = stablehlo.real_dynamic_slice %16717, %c_24, %from_elements_5860, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5861 = tensor.dim %16722, %c0 : tensor + %16723 = arith.index_cast %dim_5861 : index to i64 + %from_elements_5862 = tensor.from_elements %16723 : tensor<1xi64> + %16724 = stablehlo.dynamic_reshape %16722, %from_elements_5862 : (tensor, tensor<1xi64>) -> tensor + %16725 = stablehlo.reshape %16701 : (tensor<3x4096xf32>) -> tensor<1x3x4096xf32> + %dim_5863 = tensor.dim %16724, %c0 : tensor + %16726 = arith.index_cast %dim_5863 : index to i64 + %from_elements_5864 = tensor.from_elements %16726, %c1_i64 : tensor<2xi64> + %16727 = stablehlo.dynamic_reshape %16724, %from_elements_5864 : (tensor, tensor<2xi64>) -> tensor + %dim_5865 = tensor.dim %16727, %c0 : tensor + %16728 = arith.index_cast %dim_5865 : index to i64 + %from_elements_5866 = tensor.from_elements %c1_i64, %16728, %c4096_i64 : tensor<3xi64> + %16729 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_5866, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5867 = tensor.dim %16729, %c1 : tensor<1x?x4096xi64> + %16730 = arith.index_cast %dim_5867 : index to i64 + %from_elements_5868 = tensor.from_elements %c1_i64, %16730, %c4096_i64, %c1_i64 : tensor<4xi64> + %16731 = stablehlo.dynamic_reshape %16729, %from_elements_5868 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %16732 = stablehlo.dynamic_broadcast_in_dim %16727, %from_elements_5866, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5869 = tensor.dim %16732, %c1 : tensor<1x?x4096xi64> + %16733 
= arith.index_cast %dim_5869 : index to i64 + %from_elements_5870 = tensor.from_elements %c1_i64, %16733, %c4096_i64, %c1_i64 : tensor<4xi64> + %16734 = stablehlo.dynamic_reshape %16732, %from_elements_5870 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %16735 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_5866, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5871 = tensor.dim %16735, %c1 : tensor<1x?x4096xi64> + %16736 = arith.index_cast %dim_5871 : index to i64 + %from_elements_5872 = tensor.from_elements %c1_i64, %16736, %c4096_i64, %c1_i64 : tensor<4xi64> + %16737 = stablehlo.dynamic_reshape %16735, %from_elements_5872 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %16738 = stablehlo.concatenate %16731, %16734, %16737, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %16739 = "stablehlo.gather"(%16725, %16738) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %16740 = shape.shape_of %16739 : tensor<1x?x4096xf32> -> tensor<3xindex> + %16741 = shape.num_elements %16740 : tensor<3xindex> -> index + %16742 = stablehlo.compute_reshape_shape %16741, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %16743 = stablehlo.dynamic_reshape %16739, %16742 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %16744 = stablehlo.dot %16743, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %16745 = stablehlo.logistic %16744 : tensor + %16746 = shape.shape_of %16745 : tensor -> tensor<2xindex> + %16747 = shape.shape_of %16744 : tensor -> tensor<2xindex> + %16748 = shape.cstr_broadcastable %16746, %16747 : tensor<2xindex>, tensor<2xindex> + %16749 = shape.assuming %16748 -> (tensor) { + %19688 = shape.broadcast %16746, %16747 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %16745, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %16744, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %16750 = shape.shape_of %16749 : tensor -> tensor<2xindex> + %16751 = shape.cstr_broadcastable %16750, %16747 : tensor<2xindex>, tensor<2xindex> + %16752 = shape.assuming %16751 -> (tensor) { + %19688 = shape.broadcast %16750, %16747 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %16749, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %16744, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %16753 = stablehlo.dot %16752, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %16754 = stablehlo.reshape %16709 : (tensor<3x2xf32>) -> tensor<3x2x1xf32> + %dim_5873 = tensor.dim %16724, %c0 : tensor + %16755 = arith.index_cast %dim_5873 : index to i64 + %from_elements_5874 = tensor.from_elements %16755, %c1_i64 : tensor<2xi64> + %16756 = stablehlo.dynamic_reshape %16724, %from_elements_5874 : (tensor, tensor<2xi64>) -> tensor + %dim_5875 = tensor.dim %16721, %c0 : tensor + %16757 = arith.index_cast %dim_5875 : index to i64 + %from_elements_5876 = tensor.from_elements %16757, %c1_i64 : tensor<2xi64> + %16758 = stablehlo.dynamic_reshape %16721, 
%from_elements_5876 : (tensor, tensor<2xi64>) -> tensor + %16759 = stablehlo.concatenate %16756, %16758, dim = 1 : (tensor, tensor) -> tensor + %16760 = "stablehlo.gather"(%16754, %16759) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %16761 = shape.shape_of %16753 : tensor -> tensor<2xindex> + %16762 = shape.shape_of %16760 : tensor -> tensor<2xindex> + %16763 = shape.cstr_broadcastable %16761, %16762 : tensor<2xindex>, tensor<2xindex> + %16764 = shape.assuming %16763 -> (tensor) { + %19688 = shape.broadcast %16761, %16762 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %16753, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %16760, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %16765 = shape.shape_of %16764 : tensor -> tensor<2xindex> + %16766 = stablehlo.dynamic_broadcast_in_dim %16764, %16765, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %16767 = stablehlo.dynamic_broadcast_in_dim %213, %16765, dims = [] : (tensor, tensor<2xindex>) -> tensor + %16768 = stablehlo.multiply %16766, %16767 : tensor + %dim_5877 = tensor.dim %16727, %c0 : tensor + %16769 = arith.index_cast %dim_5877 : index to i64 + %dim_5878 = tensor.dim %16764, %c0 : tensor + %16770 = arith.index_cast %dim_5878 : index to i64 + %16771 = arith.maxsi %16769, %16770 : i64 + %16772 = arith.index_cast %16771 : i64 to index + %from_elements_5879 = tensor.from_elements %16772, %c4096 : tensor<2xindex> + %16773 = stablehlo.dynamic_broadcast_in_dim %16727, %from_elements_5879, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_5880 = tensor.dim %16773, %c0 : tensor + %16774 = arith.index_cast %dim_5880 : index to i64 + %from_elements_5881 = tensor.from_elements %16774, %c4096_i64 : tensor<2xi64> + %16775 = stablehlo.real_dynamic_slice %16768, %c_22, %from_elements_5881, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_5882 = tensor.from_elements %16774, %c4096_i64, %c1_i64 : tensor<3xi64> + %16776 = stablehlo.dynamic_reshape %16773, %from_elements_5882 : (tensor, tensor<3xi64>) -> tensor + %16777 = stablehlo.dynamic_iota %from_elements_5882, dim = 1 : (tensor<3xi64>) -> tensor + %16778 = stablehlo.concatenate %16776, %16777, dim = 2 : (tensor, tensor) -> tensor + %16779 = "stablehlo.scatter"(%cst_2, %16778, %16775) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %16780 = stablehlo.slice %16714 [1:2, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %16781 = stablehlo.reshape %16780 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %16782 = stablehlo.custom_call @byteir.non_zero(%16781) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_5883 = tensor.dim %16782, %c0 : tensor + %16783 = arith.index_cast %dim_5883 : index to i64 + %from_elements_5884 = tensor.from_elements %16783, %c1_i64 : tensor<2xi64> + %16784 = stablehlo.real_dynamic_slice %16782, %c_22, %from_elements_5884, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5885 = tensor.dim %16784, %c0 : tensor + %16785 = arith.index_cast 
%dim_5885 : index to i64 + %from_elements_5886 = tensor.from_elements %16785 : tensor<1xi64> + %16786 = stablehlo.dynamic_reshape %16784, %from_elements_5886 : (tensor, tensor<1xi64>) -> tensor + %from_elements_5887 = tensor.from_elements %16783, %c2_i64 : tensor<2xi64> + %16787 = stablehlo.real_dynamic_slice %16782, %c_24, %from_elements_5887, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5888 = tensor.dim %16787, %c0 : tensor + %16788 = arith.index_cast %dim_5888 : index to i64 + %from_elements_5889 = tensor.from_elements %16788 : tensor<1xi64> + %16789 = stablehlo.dynamic_reshape %16787, %from_elements_5889 : (tensor, tensor<1xi64>) -> tensor + %dim_5890 = tensor.dim %16789, %c0 : tensor + %16790 = arith.index_cast %dim_5890 : index to i64 + %from_elements_5891 = tensor.from_elements %16790, %c1_i64 : tensor<2xi64> + %16791 = stablehlo.dynamic_reshape %16789, %from_elements_5891 : (tensor, tensor<2xi64>) -> tensor + %dim_5892 = tensor.dim %16791, %c0 : tensor + %16792 = arith.index_cast %dim_5892 : index to i64 + %from_elements_5893 = tensor.from_elements %c1_i64, %16792, %c4096_i64 : tensor<3xi64> + %16793 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_5893, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5894 = tensor.dim %16793, %c1 : tensor<1x?x4096xi64> + %16794 = arith.index_cast %dim_5894 : index to i64 + %from_elements_5895 = tensor.from_elements %c1_i64, %16794, %c4096_i64, %c1_i64 : tensor<4xi64> + %16795 = stablehlo.dynamic_reshape %16793, %from_elements_5895 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %16796 = stablehlo.dynamic_broadcast_in_dim %16791, %from_elements_5893, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5896 = tensor.dim %16796, %c1 : tensor<1x?x4096xi64> + %16797 = arith.index_cast %dim_5896 : index to i64 + %from_elements_5897 = tensor.from_elements %c1_i64, %16797, %c4096_i64, %c1_i64 : tensor<4xi64> + %16798 = stablehlo.dynamic_reshape %16796, %from_elements_5897 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %16799 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_5893, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5898 = tensor.dim %16799, %c1 : tensor<1x?x4096xi64> + %16800 = arith.index_cast %dim_5898 : index to i64 + %from_elements_5899 = tensor.from_elements %c1_i64, %16800, %c4096_i64, %c1_i64 : tensor<4xi64> + %16801 = stablehlo.dynamic_reshape %16799, %from_elements_5899 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %16802 = stablehlo.concatenate %16795, %16798, %16801, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %16803 = "stablehlo.gather"(%16725, %16802) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %16804 = shape.shape_of %16803 : tensor<1x?x4096xf32> -> tensor<3xindex> + %16805 = shape.num_elements %16804 : tensor<3xindex> -> index + %16806 = stablehlo.compute_reshape_shape %16805, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %16807 = stablehlo.dynamic_reshape %16803, %16806 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %16808 = stablehlo.dot %16807, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %16809 = stablehlo.logistic %16808 : tensor + %16810 = shape.shape_of %16809 : tensor -> tensor<2xindex> + %16811 = shape.shape_of %16808 : 
tensor -> tensor<2xindex> + %16812 = shape.cstr_broadcastable %16810, %16811 : tensor<2xindex>, tensor<2xindex> + %16813 = shape.assuming %16812 -> (tensor) { + %19688 = shape.broadcast %16810, %16811 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %16809, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %16808, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %16814 = shape.shape_of %16813 : tensor -> tensor<2xindex> + %16815 = shape.cstr_broadcastable %16814, %16811 : tensor<2xindex>, tensor<2xindex> + %16816 = shape.assuming %16815 -> (tensor) { + %19688 = shape.broadcast %16814, %16811 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %16813, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %16808, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %16817 = stablehlo.dot %16816, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_5900 = tensor.dim %16789, %c0 : tensor + %16818 = arith.index_cast %dim_5900 : index to i64 + %from_elements_5901 = tensor.from_elements %16818, %c1_i64 : tensor<2xi64> + %16819 = stablehlo.dynamic_reshape %16789, %from_elements_5901 : (tensor, tensor<2xi64>) -> tensor + %dim_5902 = tensor.dim %16786, %c0 : tensor + %16820 = arith.index_cast %dim_5902 : index to i64 + %from_elements_5903 = tensor.from_elements %16820, %c1_i64 : tensor<2xi64> + %16821 = stablehlo.dynamic_reshape %16786, %from_elements_5903 : (tensor, tensor<2xi64>) -> tensor + %16822 = stablehlo.concatenate %16819, %16821, dim = 1 : (tensor, tensor) -> tensor + %16823 = "stablehlo.gather"(%16754, %16822) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %16824 = shape.shape_of %16817 : tensor -> tensor<2xindex> + %16825 = shape.shape_of %16823 : tensor -> tensor<2xindex> + %16826 = shape.cstr_broadcastable %16824, %16825 : tensor<2xindex>, tensor<2xindex> + %16827 = shape.assuming %16826 -> (tensor) { + %19688 = shape.broadcast %16824, %16825 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %16817, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %16823, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %16828 = shape.shape_of %16827 : tensor -> tensor<2xindex> + %16829 = stablehlo.dynamic_broadcast_in_dim %16827, %16828, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %16830 = stablehlo.dynamic_broadcast_in_dim %213, %16828, dims = [] : (tensor, tensor<2xindex>) -> tensor + %16831 = stablehlo.multiply %16829, %16830 : tensor + %dim_5904 = tensor.dim %16791, %c0 : tensor + %16832 = arith.index_cast %dim_5904 : index to i64 + %dim_5905 = tensor.dim %16827, %c0 : tensor + %16833 = arith.index_cast %dim_5905 : index to i64 + %16834 = arith.maxsi %16832, %16833 : i64 + %16835 = arith.index_cast %16834 : i64 to index + %from_elements_5906 = tensor.from_elements %16835, %c4096 : tensor<2xindex> + %16836 = stablehlo.dynamic_broadcast_in_dim %16791, %from_elements_5906, 
dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_5907 = tensor.dim %16836, %c0 : tensor + %16837 = arith.index_cast %dim_5907 : index to i64 + %from_elements_5908 = tensor.from_elements %16837, %c4096_i64 : tensor<2xi64> + %16838 = stablehlo.real_dynamic_slice %16831, %c_22, %from_elements_5908, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_5909 = tensor.from_elements %16837, %c4096_i64, %c1_i64 : tensor<3xi64> + %16839 = stablehlo.dynamic_reshape %16836, %from_elements_5909 : (tensor, tensor<3xi64>) -> tensor + %16840 = stablehlo.dynamic_iota %from_elements_5909, dim = 1 : (tensor<3xi64>) -> tensor + %16841 = stablehlo.concatenate %16839, %16840, dim = 2 : (tensor, tensor) -> tensor + %16842 = "stablehlo.scatter"(%16779, %16841, %16838) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %16843 = stablehlo.slice %16714 [2:3, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %16844 = stablehlo.reshape %16843 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %16845 = stablehlo.custom_call @byteir.non_zero(%16844) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_5910 = tensor.dim %16845, %c0 : tensor + %16846 = arith.index_cast %dim_5910 : index to i64 + %from_elements_5911 = tensor.from_elements %16846, %c1_i64 : tensor<2xi64> + %16847 = stablehlo.real_dynamic_slice %16845, %c_22, %from_elements_5911, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5912 = tensor.dim %16847, %c0 : tensor + %16848 = arith.index_cast %dim_5912 : index to i64 + %from_elements_5913 = tensor.from_elements %16848 : tensor<1xi64> + %16849 = stablehlo.dynamic_reshape %16847, %from_elements_5913 : (tensor, tensor<1xi64>) -> tensor + %from_elements_5914 = tensor.from_elements %16846, %c2_i64 : tensor<2xi64> + %16850 = stablehlo.real_dynamic_slice %16845, %c_24, %from_elements_5914, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5915 = tensor.dim %16850, %c0 : tensor + %16851 = arith.index_cast %dim_5915 : index to i64 + %from_elements_5916 = tensor.from_elements %16851 : tensor<1xi64> + %16852 = stablehlo.dynamic_reshape %16850, %from_elements_5916 : (tensor, tensor<1xi64>) -> tensor + %dim_5917 = tensor.dim %16852, %c0 : tensor + %16853 = arith.index_cast %dim_5917 : index to i64 + %from_elements_5918 = tensor.from_elements %16853, %c1_i64 : tensor<2xi64> + %16854 = stablehlo.dynamic_reshape %16852, %from_elements_5918 : (tensor, tensor<2xi64>) -> tensor + %dim_5919 = tensor.dim %16854, %c0 : tensor + %16855 = arith.index_cast %dim_5919 : index to i64 + %from_elements_5920 = tensor.from_elements %c1_i64, %16855, %c4096_i64 : tensor<3xi64> + %16856 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_5920, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5921 = tensor.dim %16856, %c1 : tensor<1x?x4096xi64> + %16857 = arith.index_cast %dim_5921 : index to i64 + %from_elements_5922 = tensor.from_elements %c1_i64, %16857, %c4096_i64, %c1_i64 : tensor<4xi64> + %16858 = stablehlo.dynamic_reshape %16856, %from_elements_5922 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %16859 = stablehlo.dynamic_broadcast_in_dim %16854, %from_elements_5920, dims = [1, 2] : (tensor, tensor<3xi64>) -> 
tensor<1x?x4096xi64> + %dim_5923 = tensor.dim %16859, %c1 : tensor<1x?x4096xi64> + %16860 = arith.index_cast %dim_5923 : index to i64 + %from_elements_5924 = tensor.from_elements %c1_i64, %16860, %c4096_i64, %c1_i64 : tensor<4xi64> + %16861 = stablehlo.dynamic_reshape %16859, %from_elements_5924 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %16862 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_5920, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5925 = tensor.dim %16862, %c1 : tensor<1x?x4096xi64> + %16863 = arith.index_cast %dim_5925 : index to i64 + %from_elements_5926 = tensor.from_elements %c1_i64, %16863, %c4096_i64, %c1_i64 : tensor<4xi64> + %16864 = stablehlo.dynamic_reshape %16862, %from_elements_5926 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %16865 = stablehlo.concatenate %16858, %16861, %16864, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %16866 = "stablehlo.gather"(%16725, %16865) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %16867 = shape.shape_of %16866 : tensor<1x?x4096xf32> -> tensor<3xindex> + %16868 = shape.num_elements %16867 : tensor<3xindex> -> index + %16869 = stablehlo.compute_reshape_shape %16868, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %16870 = stablehlo.dynamic_reshape %16866, %16869 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %16871 = stablehlo.dot %16870, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %16872 = stablehlo.logistic %16871 : tensor + %16873 = shape.shape_of %16872 : tensor -> tensor<2xindex> + %16874 = shape.shape_of %16871 : tensor -> tensor<2xindex> + %16875 = shape.cstr_broadcastable %16873, %16874 : tensor<2xindex>, tensor<2xindex> + %16876 = shape.assuming %16875 -> (tensor) { + %19688 = shape.broadcast %16873, %16874 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %16872, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %16871, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %16877 = shape.shape_of %16876 : tensor -> tensor<2xindex> + %16878 = shape.cstr_broadcastable %16877, %16874 : tensor<2xindex>, tensor<2xindex> + %16879 = shape.assuming %16878 -> (tensor) { + %19688 = shape.broadcast %16877, %16874 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %16876, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %16871, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %16880 = stablehlo.dot %16879, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_5927 = tensor.dim %16852, %c0 : tensor + %16881 = arith.index_cast %dim_5927 : index to i64 + %from_elements_5928 = tensor.from_elements %16881, %c1_i64 : tensor<2xi64> + %16882 = stablehlo.dynamic_reshape %16852, %from_elements_5928 : (tensor, tensor<2xi64>) -> tensor + %dim_5929 = tensor.dim %16849, %c0 : tensor + %16883 = arith.index_cast %dim_5929 : index to i64 + %from_elements_5930 = tensor.from_elements %16883, %c1_i64 : tensor<2xi64> + %16884 = 
stablehlo.dynamic_reshape %16849, %from_elements_5930 : (tensor, tensor<2xi64>) -> tensor + %16885 = stablehlo.concatenate %16882, %16884, dim = 1 : (tensor, tensor) -> tensor + %16886 = "stablehlo.gather"(%16754, %16885) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %16887 = shape.shape_of %16880 : tensor -> tensor<2xindex> + %16888 = shape.shape_of %16886 : tensor -> tensor<2xindex> + %16889 = shape.cstr_broadcastable %16887, %16888 : tensor<2xindex>, tensor<2xindex> + %16890 = shape.assuming %16889 -> (tensor) { + %19688 = shape.broadcast %16887, %16888 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %16880, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %16886, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %16891 = shape.shape_of %16890 : tensor -> tensor<2xindex> + %16892 = stablehlo.dynamic_broadcast_in_dim %16890, %16891, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %16893 = stablehlo.dynamic_broadcast_in_dim %213, %16891, dims = [] : (tensor, tensor<2xindex>) -> tensor + %16894 = stablehlo.multiply %16892, %16893 : tensor + %dim_5931 = tensor.dim %16854, %c0 : tensor + %16895 = arith.index_cast %dim_5931 : index to i64 + %dim_5932 = tensor.dim %16890, %c0 : tensor + %16896 = arith.index_cast %dim_5932 : index to i64 + %16897 = arith.maxsi %16895, %16896 : i64 + %16898 = arith.index_cast %16897 : i64 to index + %from_elements_5933 = tensor.from_elements %16898, %c4096 : tensor<2xindex> + %16899 = stablehlo.dynamic_broadcast_in_dim %16854, %from_elements_5933, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_5934 = tensor.dim %16899, %c0 : tensor + %16900 = arith.index_cast %dim_5934 : index to i64 + %from_elements_5935 = tensor.from_elements %16900, %c4096_i64 : tensor<2xi64> + %16901 = stablehlo.real_dynamic_slice %16894, %c_22, %from_elements_5935, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_5936 = tensor.from_elements %16900, %c4096_i64, %c1_i64 : tensor<3xi64> + %16902 = stablehlo.dynamic_reshape %16899, %from_elements_5936 : (tensor, tensor<3xi64>) -> tensor + %16903 = stablehlo.dynamic_iota %from_elements_5936, dim = 1 : (tensor<3xi64>) -> tensor + %16904 = stablehlo.concatenate %16902, %16903, dim = 2 : (tensor, tensor) -> tensor + %16905 = "stablehlo.scatter"(%16842, %16904, %16901) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %16906 = stablehlo.slice %16714 [3:4, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %16907 = stablehlo.reshape %16906 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %16908 = stablehlo.custom_call @byteir.non_zero(%16907) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_5937 = tensor.dim %16908, %c0 : tensor + %16909 = arith.index_cast %dim_5937 : index to i64 + %from_elements_5938 = tensor.from_elements %16909, %c1_i64 : tensor<2xi64> + %16910 = stablehlo.real_dynamic_slice %16908, %c_22, %from_elements_5938, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5939 = tensor.dim %16910, %c0 : 
tensor + %16911 = arith.index_cast %dim_5939 : index to i64 + %from_elements_5940 = tensor.from_elements %16911 : tensor<1xi64> + %16912 = stablehlo.dynamic_reshape %16910, %from_elements_5940 : (tensor, tensor<1xi64>) -> tensor + %from_elements_5941 = tensor.from_elements %16909, %c2_i64 : tensor<2xi64> + %16913 = stablehlo.real_dynamic_slice %16908, %c_24, %from_elements_5941, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5942 = tensor.dim %16913, %c0 : tensor + %16914 = arith.index_cast %dim_5942 : index to i64 + %from_elements_5943 = tensor.from_elements %16914 : tensor<1xi64> + %16915 = stablehlo.dynamic_reshape %16913, %from_elements_5943 : (tensor, tensor<1xi64>) -> tensor + %dim_5944 = tensor.dim %16915, %c0 : tensor + %16916 = arith.index_cast %dim_5944 : index to i64 + %from_elements_5945 = tensor.from_elements %16916, %c1_i64 : tensor<2xi64> + %16917 = stablehlo.dynamic_reshape %16915, %from_elements_5945 : (tensor, tensor<2xi64>) -> tensor + %dim_5946 = tensor.dim %16917, %c0 : tensor + %16918 = arith.index_cast %dim_5946 : index to i64 + %from_elements_5947 = tensor.from_elements %c1_i64, %16918, %c4096_i64 : tensor<3xi64> + %16919 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_5947, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5948 = tensor.dim %16919, %c1 : tensor<1x?x4096xi64> + %16920 = arith.index_cast %dim_5948 : index to i64 + %from_elements_5949 = tensor.from_elements %c1_i64, %16920, %c4096_i64, %c1_i64 : tensor<4xi64> + %16921 = stablehlo.dynamic_reshape %16919, %from_elements_5949 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %16922 = stablehlo.dynamic_broadcast_in_dim %16917, %from_elements_5947, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5950 = tensor.dim %16922, %c1 : tensor<1x?x4096xi64> + %16923 = arith.index_cast %dim_5950 : index to i64 + %from_elements_5951 = tensor.from_elements %c1_i64, %16923, %c4096_i64, %c1_i64 : tensor<4xi64> + %16924 = stablehlo.dynamic_reshape %16922, %from_elements_5951 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %16925 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_5947, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5952 = tensor.dim %16925, %c1 : tensor<1x?x4096xi64> + %16926 = arith.index_cast %dim_5952 : index to i64 + %from_elements_5953 = tensor.from_elements %c1_i64, %16926, %c4096_i64, %c1_i64 : tensor<4xi64> + %16927 = stablehlo.dynamic_reshape %16925, %from_elements_5953 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %16928 = stablehlo.concatenate %16921, %16924, %16927, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %16929 = "stablehlo.gather"(%16725, %16928) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %16930 = shape.shape_of %16929 : tensor<1x?x4096xf32> -> tensor<3xindex> + %16931 = shape.num_elements %16930 : tensor<3xindex> -> index + %16932 = stablehlo.compute_reshape_shape %16931, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %16933 = stablehlo.dynamic_reshape %16929, %16932 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %16934 = stablehlo.dot %16933, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %16935 = stablehlo.logistic %16934 : tensor + %16936 = shape.shape_of %16935 : tensor -> tensor<2xindex> + 
%16937 = shape.shape_of %16934 : tensor -> tensor<2xindex> + %16938 = shape.cstr_broadcastable %16936, %16937 : tensor<2xindex>, tensor<2xindex> + %16939 = shape.assuming %16938 -> (tensor) { + %19688 = shape.broadcast %16936, %16937 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %16935, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %16934, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %16940 = shape.shape_of %16939 : tensor -> tensor<2xindex> + %16941 = shape.cstr_broadcastable %16940, %16937 : tensor<2xindex>, tensor<2xindex> + %16942 = shape.assuming %16941 -> (tensor) { + %19688 = shape.broadcast %16940, %16937 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %16939, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %16934, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %16943 = stablehlo.dot %16942, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_5954 = tensor.dim %16915, %c0 : tensor + %16944 = arith.index_cast %dim_5954 : index to i64 + %from_elements_5955 = tensor.from_elements %16944, %c1_i64 : tensor<2xi64> + %16945 = stablehlo.dynamic_reshape %16915, %from_elements_5955 : (tensor, tensor<2xi64>) -> tensor + %dim_5956 = tensor.dim %16912, %c0 : tensor + %16946 = arith.index_cast %dim_5956 : index to i64 + %from_elements_5957 = tensor.from_elements %16946, %c1_i64 : tensor<2xi64> + %16947 = stablehlo.dynamic_reshape %16912, %from_elements_5957 : (tensor, tensor<2xi64>) -> tensor + %16948 = stablehlo.concatenate %16945, %16947, dim = 1 : (tensor, tensor) -> tensor + %16949 = "stablehlo.gather"(%16754, %16948) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %16950 = shape.shape_of %16943 : tensor -> tensor<2xindex> + %16951 = shape.shape_of %16949 : tensor -> tensor<2xindex> + %16952 = shape.cstr_broadcastable %16950, %16951 : tensor<2xindex>, tensor<2xindex> + %16953 = shape.assuming %16952 -> (tensor) { + %19688 = shape.broadcast %16950, %16951 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %16943, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %16949, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %16954 = shape.shape_of %16953 : tensor -> tensor<2xindex> + %16955 = stablehlo.dynamic_broadcast_in_dim %16953, %16954, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %16956 = stablehlo.dynamic_broadcast_in_dim %213, %16954, dims = [] : (tensor, tensor<2xindex>) -> tensor + %16957 = stablehlo.multiply %16955, %16956 : tensor + %dim_5958 = tensor.dim %16917, %c0 : tensor + %16958 = arith.index_cast %dim_5958 : index to i64 + %dim_5959 = tensor.dim %16953, %c0 : tensor + %16959 = arith.index_cast %dim_5959 : index to i64 + %16960 = arith.maxsi %16958, %16959 : i64 + %16961 = arith.index_cast %16960 : i64 to index + %from_elements_5960 = tensor.from_elements %16961, %c4096 : tensor<2xindex> + %16962 = 
stablehlo.dynamic_broadcast_in_dim %16917, %from_elements_5960, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_5961 = tensor.dim %16962, %c0 : tensor + %16963 = arith.index_cast %dim_5961 : index to i64 + %from_elements_5962 = tensor.from_elements %16963, %c4096_i64 : tensor<2xi64> + %16964 = stablehlo.real_dynamic_slice %16957, %c_22, %from_elements_5962, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_5963 = tensor.from_elements %16963, %c4096_i64, %c1_i64 : tensor<3xi64> + %16965 = stablehlo.dynamic_reshape %16962, %from_elements_5963 : (tensor, tensor<3xi64>) -> tensor + %16966 = stablehlo.dynamic_iota %from_elements_5963, dim = 1 : (tensor<3xi64>) -> tensor + %16967 = stablehlo.concatenate %16965, %16966, dim = 2 : (tensor, tensor) -> tensor + %16968 = "stablehlo.scatter"(%16905, %16967, %16964) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %16969 = stablehlo.slice %16714 [4:5, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %16970 = stablehlo.reshape %16969 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %16971 = stablehlo.custom_call @byteir.non_zero(%16970) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_5964 = tensor.dim %16971, %c0 : tensor + %16972 = arith.index_cast %dim_5964 : index to i64 + %from_elements_5965 = tensor.from_elements %16972, %c1_i64 : tensor<2xi64> + %16973 = stablehlo.real_dynamic_slice %16971, %c_22, %from_elements_5965, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5966 = tensor.dim %16973, %c0 : tensor + %16974 = arith.index_cast %dim_5966 : index to i64 + %from_elements_5967 = tensor.from_elements %16974 : tensor<1xi64> + %16975 = stablehlo.dynamic_reshape %16973, %from_elements_5967 : (tensor, tensor<1xi64>) -> tensor + %from_elements_5968 = tensor.from_elements %16972, %c2_i64 : tensor<2xi64> + %16976 = stablehlo.real_dynamic_slice %16971, %c_24, %from_elements_5968, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5969 = tensor.dim %16976, %c0 : tensor + %16977 = arith.index_cast %dim_5969 : index to i64 + %from_elements_5970 = tensor.from_elements %16977 : tensor<1xi64> + %16978 = stablehlo.dynamic_reshape %16976, %from_elements_5970 : (tensor, tensor<1xi64>) -> tensor + %dim_5971 = tensor.dim %16978, %c0 : tensor + %16979 = arith.index_cast %dim_5971 : index to i64 + %from_elements_5972 = tensor.from_elements %16979, %c1_i64 : tensor<2xi64> + %16980 = stablehlo.dynamic_reshape %16978, %from_elements_5972 : (tensor, tensor<2xi64>) -> tensor + %dim_5973 = tensor.dim %16980, %c0 : tensor + %16981 = arith.index_cast %dim_5973 : index to i64 + %from_elements_5974 = tensor.from_elements %c1_i64, %16981, %c4096_i64 : tensor<3xi64> + %16982 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_5974, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5975 = tensor.dim %16982, %c1 : tensor<1x?x4096xi64> + %16983 = arith.index_cast %dim_5975 : index to i64 + %from_elements_5976 = tensor.from_elements %c1_i64, %16983, %c4096_i64, %c1_i64 : tensor<4xi64> + %16984 = stablehlo.dynamic_reshape %16982, %from_elements_5976 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %16985 = stablehlo.dynamic_broadcast_in_dim %16980, 
%from_elements_5974, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5977 = tensor.dim %16985, %c1 : tensor<1x?x4096xi64> + %16986 = arith.index_cast %dim_5977 : index to i64 + %from_elements_5978 = tensor.from_elements %c1_i64, %16986, %c4096_i64, %c1_i64 : tensor<4xi64> + %16987 = stablehlo.dynamic_reshape %16985, %from_elements_5978 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %16988 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_5974, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_5979 = tensor.dim %16988, %c1 : tensor<1x?x4096xi64> + %16989 = arith.index_cast %dim_5979 : index to i64 + %from_elements_5980 = tensor.from_elements %c1_i64, %16989, %c4096_i64, %c1_i64 : tensor<4xi64> + %16990 = stablehlo.dynamic_reshape %16988, %from_elements_5980 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %16991 = stablehlo.concatenate %16984, %16987, %16990, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %16992 = "stablehlo.gather"(%16725, %16991) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %16993 = shape.shape_of %16992 : tensor<1x?x4096xf32> -> tensor<3xindex> + %16994 = shape.num_elements %16993 : tensor<3xindex> -> index + %16995 = stablehlo.compute_reshape_shape %16994, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %16996 = stablehlo.dynamic_reshape %16992, %16995 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %16997 = stablehlo.dot %16996, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %16998 = stablehlo.logistic %16997 : tensor + %16999 = shape.shape_of %16998 : tensor -> tensor<2xindex> + %17000 = shape.shape_of %16997 : tensor -> tensor<2xindex> + %17001 = shape.cstr_broadcastable %16999, %17000 : tensor<2xindex>, tensor<2xindex> + %17002 = shape.assuming %17001 -> (tensor) { + %19688 = shape.broadcast %16999, %17000 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %16998, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %16997, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %17003 = shape.shape_of %17002 : tensor -> tensor<2xindex> + %17004 = shape.cstr_broadcastable %17003, %17000 : tensor<2xindex>, tensor<2xindex> + %17005 = shape.assuming %17004 -> (tensor) { + %19688 = shape.broadcast %17003, %17000 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %17002, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %16997, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %17006 = stablehlo.dot %17005, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_5981 = tensor.dim %16978, %c0 : tensor + %17007 = arith.index_cast %dim_5981 : index to i64 + %from_elements_5982 = tensor.from_elements %17007, %c1_i64 : tensor<2xi64> + %17008 = stablehlo.dynamic_reshape %16978, %from_elements_5982 : (tensor, tensor<2xi64>) -> tensor + %dim_5983 = tensor.dim %16975, %c0 : tensor + %17009 = arith.index_cast %dim_5983 : index to i64 + %from_elements_5984 = tensor.from_elements 
%17009, %c1_i64 : tensor<2xi64> + %17010 = stablehlo.dynamic_reshape %16975, %from_elements_5984 : (tensor, tensor<2xi64>) -> tensor + %17011 = stablehlo.concatenate %17008, %17010, dim = 1 : (tensor, tensor) -> tensor + %17012 = "stablehlo.gather"(%16754, %17011) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %17013 = shape.shape_of %17006 : tensor -> tensor<2xindex> + %17014 = shape.shape_of %17012 : tensor -> tensor<2xindex> + %17015 = shape.cstr_broadcastable %17013, %17014 : tensor<2xindex>, tensor<2xindex> + %17016 = shape.assuming %17015 -> (tensor) { + %19688 = shape.broadcast %17013, %17014 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %17006, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %17012, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %17017 = shape.shape_of %17016 : tensor -> tensor<2xindex> + %17018 = stablehlo.dynamic_broadcast_in_dim %17016, %17017, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %17019 = stablehlo.dynamic_broadcast_in_dim %213, %17017, dims = [] : (tensor, tensor<2xindex>) -> tensor + %17020 = stablehlo.multiply %17018, %17019 : tensor + %dim_5985 = tensor.dim %16980, %c0 : tensor + %17021 = arith.index_cast %dim_5985 : index to i64 + %dim_5986 = tensor.dim %17016, %c0 : tensor + %17022 = arith.index_cast %dim_5986 : index to i64 + %17023 = arith.maxsi %17021, %17022 : i64 + %17024 = arith.index_cast %17023 : i64 to index + %from_elements_5987 = tensor.from_elements %17024, %c4096 : tensor<2xindex> + %17025 = stablehlo.dynamic_broadcast_in_dim %16980, %from_elements_5987, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_5988 = tensor.dim %17025, %c0 : tensor + %17026 = arith.index_cast %dim_5988 : index to i64 + %from_elements_5989 = tensor.from_elements %17026, %c4096_i64 : tensor<2xi64> + %17027 = stablehlo.real_dynamic_slice %17020, %c_22, %from_elements_5989, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_5990 = tensor.from_elements %17026, %c4096_i64, %c1_i64 : tensor<3xi64> + %17028 = stablehlo.dynamic_reshape %17025, %from_elements_5990 : (tensor, tensor<3xi64>) -> tensor + %17029 = stablehlo.dynamic_iota %from_elements_5990, dim = 1 : (tensor<3xi64>) -> tensor + %17030 = stablehlo.concatenate %17028, %17029, dim = 2 : (tensor, tensor) -> tensor + %17031 = "stablehlo.scatter"(%16968, %17030, %17027) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %17032 = stablehlo.slice %16714 [5:6, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %17033 = stablehlo.reshape %17032 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %17034 = stablehlo.custom_call @byteir.non_zero(%17033) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_5991 = tensor.dim %17034, %c0 : tensor + %17035 = arith.index_cast %dim_5991 : index to i64 + %from_elements_5992 = tensor.from_elements %17035, %c1_i64 : tensor<2xi64> + %17036 = stablehlo.real_dynamic_slice %17034, %c_22, %from_elements_5992, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> 
tensor + %dim_5993 = tensor.dim %17036, %c0 : tensor + %17037 = arith.index_cast %dim_5993 : index to i64 + %from_elements_5994 = tensor.from_elements %17037 : tensor<1xi64> + %17038 = stablehlo.dynamic_reshape %17036, %from_elements_5994 : (tensor, tensor<1xi64>) -> tensor + %from_elements_5995 = tensor.from_elements %17035, %c2_i64 : tensor<2xi64> + %17039 = stablehlo.real_dynamic_slice %17034, %c_24, %from_elements_5995, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_5996 = tensor.dim %17039, %c0 : tensor + %17040 = arith.index_cast %dim_5996 : index to i64 + %from_elements_5997 = tensor.from_elements %17040 : tensor<1xi64> + %17041 = stablehlo.dynamic_reshape %17039, %from_elements_5997 : (tensor, tensor<1xi64>) -> tensor + %dim_5998 = tensor.dim %17041, %c0 : tensor + %17042 = arith.index_cast %dim_5998 : index to i64 + %from_elements_5999 = tensor.from_elements %17042, %c1_i64 : tensor<2xi64> + %17043 = stablehlo.dynamic_reshape %17041, %from_elements_5999 : (tensor, tensor<2xi64>) -> tensor + %dim_6000 = tensor.dim %17043, %c0 : tensor + %17044 = arith.index_cast %dim_6000 : index to i64 + %from_elements_6001 = tensor.from_elements %c1_i64, %17044, %c4096_i64 : tensor<3xi64> + %17045 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_6001, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6002 = tensor.dim %17045, %c1 : tensor<1x?x4096xi64> + %17046 = arith.index_cast %dim_6002 : index to i64 + %from_elements_6003 = tensor.from_elements %c1_i64, %17046, %c4096_i64, %c1_i64 : tensor<4xi64> + %17047 = stablehlo.dynamic_reshape %17045, %from_elements_6003 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %17048 = stablehlo.dynamic_broadcast_in_dim %17043, %from_elements_6001, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6004 = tensor.dim %17048, %c1 : tensor<1x?x4096xi64> + %17049 = arith.index_cast %dim_6004 : index to i64 + %from_elements_6005 = tensor.from_elements %c1_i64, %17049, %c4096_i64, %c1_i64 : tensor<4xi64> + %17050 = stablehlo.dynamic_reshape %17048, %from_elements_6005 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %17051 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_6001, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6006 = tensor.dim %17051, %c1 : tensor<1x?x4096xi64> + %17052 = arith.index_cast %dim_6006 : index to i64 + %from_elements_6007 = tensor.from_elements %c1_i64, %17052, %c4096_i64, %c1_i64 : tensor<4xi64> + %17053 = stablehlo.dynamic_reshape %17051, %from_elements_6007 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %17054 = stablehlo.concatenate %17047, %17050, %17053, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %17055 = "stablehlo.gather"(%16725, %17054) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %17056 = shape.shape_of %17055 : tensor<1x?x4096xf32> -> tensor<3xindex> + %17057 = shape.num_elements %17056 : tensor<3xindex> -> index + %17058 = stablehlo.compute_reshape_shape %17057, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %17059 = stablehlo.dynamic_reshape %17055, %17058 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %17060 = stablehlo.dot %17059, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %17061 = stablehlo.logistic %17060 : tensor + %17062 = 
shape.shape_of %17061 : tensor -> tensor<2xindex> + %17063 = shape.shape_of %17060 : tensor -> tensor<2xindex> + %17064 = shape.cstr_broadcastable %17062, %17063 : tensor<2xindex>, tensor<2xindex> + %17065 = shape.assuming %17064 -> (tensor) { + %19688 = shape.broadcast %17062, %17063 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %17061, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %17060, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %17066 = shape.shape_of %17065 : tensor -> tensor<2xindex> + %17067 = shape.cstr_broadcastable %17066, %17063 : tensor<2xindex>, tensor<2xindex> + %17068 = shape.assuming %17067 -> (tensor) { + %19688 = shape.broadcast %17066, %17063 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %17065, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %17060, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %17069 = stablehlo.dot %17068, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_6008 = tensor.dim %17041, %c0 : tensor + %17070 = arith.index_cast %dim_6008 : index to i64 + %from_elements_6009 = tensor.from_elements %17070, %c1_i64 : tensor<2xi64> + %17071 = stablehlo.dynamic_reshape %17041, %from_elements_6009 : (tensor, tensor<2xi64>) -> tensor + %dim_6010 = tensor.dim %17038, %c0 : tensor + %17072 = arith.index_cast %dim_6010 : index to i64 + %from_elements_6011 = tensor.from_elements %17072, %c1_i64 : tensor<2xi64> + %17073 = stablehlo.dynamic_reshape %17038, %from_elements_6011 : (tensor, tensor<2xi64>) -> tensor + %17074 = stablehlo.concatenate %17071, %17073, dim = 1 : (tensor, tensor) -> tensor + %17075 = "stablehlo.gather"(%16754, %17074) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %17076 = shape.shape_of %17069 : tensor -> tensor<2xindex> + %17077 = shape.shape_of %17075 : tensor -> tensor<2xindex> + %17078 = shape.cstr_broadcastable %17076, %17077 : tensor<2xindex>, tensor<2xindex> + %17079 = shape.assuming %17078 -> (tensor) { + %19688 = shape.broadcast %17076, %17077 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %17069, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %17075, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %17080 = shape.shape_of %17079 : tensor -> tensor<2xindex> + %17081 = stablehlo.dynamic_broadcast_in_dim %17079, %17080, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %17082 = stablehlo.dynamic_broadcast_in_dim %213, %17080, dims = [] : (tensor, tensor<2xindex>) -> tensor + %17083 = stablehlo.multiply %17081, %17082 : tensor + %dim_6012 = tensor.dim %17043, %c0 : tensor + %17084 = arith.index_cast %dim_6012 : index to i64 + %dim_6013 = tensor.dim %17079, %c0 : tensor + %17085 = arith.index_cast %dim_6013 : index to i64 + %17086 = arith.maxsi %17084, %17085 : i64 + %17087 = arith.index_cast %17086 : i64 to index + %from_elements_6014 = tensor.from_elements %17087, %c4096 : 
tensor<2xindex> + %17088 = stablehlo.dynamic_broadcast_in_dim %17043, %from_elements_6014, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_6015 = tensor.dim %17088, %c0 : tensor + %17089 = arith.index_cast %dim_6015 : index to i64 + %from_elements_6016 = tensor.from_elements %17089, %c4096_i64 : tensor<2xi64> + %17090 = stablehlo.real_dynamic_slice %17083, %c_22, %from_elements_6016, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_6017 = tensor.from_elements %17089, %c4096_i64, %c1_i64 : tensor<3xi64> + %17091 = stablehlo.dynamic_reshape %17088, %from_elements_6017 : (tensor, tensor<3xi64>) -> tensor + %17092 = stablehlo.dynamic_iota %from_elements_6017, dim = 1 : (tensor<3xi64>) -> tensor + %17093 = stablehlo.concatenate %17091, %17092, dim = 2 : (tensor, tensor) -> tensor + %17094 = "stablehlo.scatter"(%17031, %17093, %17090) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %17095 = stablehlo.slice %16714 [6:7, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %17096 = stablehlo.reshape %17095 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %17097 = stablehlo.custom_call @byteir.non_zero(%17096) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_6018 = tensor.dim %17097, %c0 : tensor + %17098 = arith.index_cast %dim_6018 : index to i64 + %from_elements_6019 = tensor.from_elements %17098, %c1_i64 : tensor<2xi64> + %17099 = stablehlo.real_dynamic_slice %17097, %c_22, %from_elements_6019, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6020 = tensor.dim %17099, %c0 : tensor + %17100 = arith.index_cast %dim_6020 : index to i64 + %from_elements_6021 = tensor.from_elements %17100 : tensor<1xi64> + %17101 = stablehlo.dynamic_reshape %17099, %from_elements_6021 : (tensor, tensor<1xi64>) -> tensor + %from_elements_6022 = tensor.from_elements %17098, %c2_i64 : tensor<2xi64> + %17102 = stablehlo.real_dynamic_slice %17097, %c_24, %from_elements_6022, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6023 = tensor.dim %17102, %c0 : tensor + %17103 = arith.index_cast %dim_6023 : index to i64 + %from_elements_6024 = tensor.from_elements %17103 : tensor<1xi64> + %17104 = stablehlo.dynamic_reshape %17102, %from_elements_6024 : (tensor, tensor<1xi64>) -> tensor + %dim_6025 = tensor.dim %17104, %c0 : tensor + %17105 = arith.index_cast %dim_6025 : index to i64 + %from_elements_6026 = tensor.from_elements %17105, %c1_i64 : tensor<2xi64> + %17106 = stablehlo.dynamic_reshape %17104, %from_elements_6026 : (tensor, tensor<2xi64>) -> tensor + %dim_6027 = tensor.dim %17106, %c0 : tensor + %17107 = arith.index_cast %dim_6027 : index to i64 + %from_elements_6028 = tensor.from_elements %c1_i64, %17107, %c4096_i64 : tensor<3xi64> + %17108 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_6028, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6029 = tensor.dim %17108, %c1 : tensor<1x?x4096xi64> + %17109 = arith.index_cast %dim_6029 : index to i64 + %from_elements_6030 = tensor.from_elements %c1_i64, %17109, %c4096_i64, %c1_i64 : tensor<4xi64> + %17110 = stablehlo.dynamic_reshape %17108, %from_elements_6030 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %17111 = 
stablehlo.dynamic_broadcast_in_dim %17106, %from_elements_6028, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6031 = tensor.dim %17111, %c1 : tensor<1x?x4096xi64> + %17112 = arith.index_cast %dim_6031 : index to i64 + %from_elements_6032 = tensor.from_elements %c1_i64, %17112, %c4096_i64, %c1_i64 : tensor<4xi64> + %17113 = stablehlo.dynamic_reshape %17111, %from_elements_6032 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %17114 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_6028, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6033 = tensor.dim %17114, %c1 : tensor<1x?x4096xi64> + %17115 = arith.index_cast %dim_6033 : index to i64 + %from_elements_6034 = tensor.from_elements %c1_i64, %17115, %c4096_i64, %c1_i64 : tensor<4xi64> + %17116 = stablehlo.dynamic_reshape %17114, %from_elements_6034 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %17117 = stablehlo.concatenate %17110, %17113, %17116, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %17118 = "stablehlo.gather"(%16725, %17117) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %17119 = shape.shape_of %17118 : tensor<1x?x4096xf32> -> tensor<3xindex> + %17120 = shape.num_elements %17119 : tensor<3xindex> -> index + %17121 = stablehlo.compute_reshape_shape %17120, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %17122 = stablehlo.dynamic_reshape %17118, %17121 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %17123 = stablehlo.dot %17122, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %17124 = stablehlo.logistic %17123 : tensor + %17125 = shape.shape_of %17124 : tensor -> tensor<2xindex> + %17126 = shape.shape_of %17123 : tensor -> tensor<2xindex> + %17127 = shape.cstr_broadcastable %17125, %17126 : tensor<2xindex>, tensor<2xindex> + %17128 = shape.assuming %17127 -> (tensor) { + %19688 = shape.broadcast %17125, %17126 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %17124, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %17123, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %17129 = shape.shape_of %17128 : tensor -> tensor<2xindex> + %17130 = shape.cstr_broadcastable %17129, %17126 : tensor<2xindex>, tensor<2xindex> + %17131 = shape.assuming %17130 -> (tensor) { + %19688 = shape.broadcast %17129, %17126 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %17128, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %17123, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %17132 = stablehlo.dot %17131, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_6035 = tensor.dim %17104, %c0 : tensor + %17133 = arith.index_cast %dim_6035 : index to i64 + %from_elements_6036 = tensor.from_elements %17133, %c1_i64 : tensor<2xi64> + %17134 = stablehlo.dynamic_reshape %17104, %from_elements_6036 : (tensor, tensor<2xi64>) -> tensor + %dim_6037 = tensor.dim %17101, %c0 : tensor + %17135 = arith.index_cast %dim_6037 : index to i64 + 
%from_elements_6038 = tensor.from_elements %17135, %c1_i64 : tensor<2xi64> + %17136 = stablehlo.dynamic_reshape %17101, %from_elements_6038 : (tensor, tensor<2xi64>) -> tensor + %17137 = stablehlo.concatenate %17134, %17136, dim = 1 : (tensor, tensor) -> tensor + %17138 = "stablehlo.gather"(%16754, %17137) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %17139 = shape.shape_of %17132 : tensor -> tensor<2xindex> + %17140 = shape.shape_of %17138 : tensor -> tensor<2xindex> + %17141 = shape.cstr_broadcastable %17139, %17140 : tensor<2xindex>, tensor<2xindex> + %17142 = shape.assuming %17141 -> (tensor) { + %19688 = shape.broadcast %17139, %17140 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %17132, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %17138, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %17143 = shape.shape_of %17142 : tensor -> tensor<2xindex> + %17144 = stablehlo.dynamic_broadcast_in_dim %17142, %17143, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %17145 = stablehlo.dynamic_broadcast_in_dim %213, %17143, dims = [] : (tensor, tensor<2xindex>) -> tensor + %17146 = stablehlo.multiply %17144, %17145 : tensor + %dim_6039 = tensor.dim %17106, %c0 : tensor + %17147 = arith.index_cast %dim_6039 : index to i64 + %dim_6040 = tensor.dim %17142, %c0 : tensor + %17148 = arith.index_cast %dim_6040 : index to i64 + %17149 = arith.maxsi %17147, %17148 : i64 + %17150 = arith.index_cast %17149 : i64 to index + %from_elements_6041 = tensor.from_elements %17150, %c4096 : tensor<2xindex> + %17151 = stablehlo.dynamic_broadcast_in_dim %17106, %from_elements_6041, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_6042 = tensor.dim %17151, %c0 : tensor + %17152 = arith.index_cast %dim_6042 : index to i64 + %from_elements_6043 = tensor.from_elements %17152, %c4096_i64 : tensor<2xi64> + %17153 = stablehlo.real_dynamic_slice %17146, %c_22, %from_elements_6043, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_6044 = tensor.from_elements %17152, %c4096_i64, %c1_i64 : tensor<3xi64> + %17154 = stablehlo.dynamic_reshape %17151, %from_elements_6044 : (tensor, tensor<3xi64>) -> tensor + %17155 = stablehlo.dynamic_iota %from_elements_6044, dim = 1 : (tensor<3xi64>) -> tensor + %17156 = stablehlo.concatenate %17154, %17155, dim = 2 : (tensor, tensor) -> tensor + %17157 = "stablehlo.scatter"(%17094, %17156, %17153) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %17158 = stablehlo.slice %16714 [7:8, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %17159 = stablehlo.reshape %17158 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %17160 = stablehlo.custom_call @byteir.non_zero(%17159) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_6045 = tensor.dim %17160, %c0 : tensor + %17161 = arith.index_cast %dim_6045 : index to i64 + %from_elements_6046 = tensor.from_elements %17161, %c1_i64 : tensor<2xi64> + %17162 = stablehlo.real_dynamic_slice %17160, %c_22, %from_elements_6046, %c_21 : (tensor, 
tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6047 = tensor.dim %17162, %c0 : tensor + %17163 = arith.index_cast %dim_6047 : index to i64 + %from_elements_6048 = tensor.from_elements %17163 : tensor<1xi64> + %17164 = stablehlo.dynamic_reshape %17162, %from_elements_6048 : (tensor, tensor<1xi64>) -> tensor + %from_elements_6049 = tensor.from_elements %17161, %c2_i64 : tensor<2xi64> + %17165 = stablehlo.real_dynamic_slice %17160, %c_24, %from_elements_6049, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6050 = tensor.dim %17165, %c0 : tensor + %17166 = arith.index_cast %dim_6050 : index to i64 + %from_elements_6051 = tensor.from_elements %17166 : tensor<1xi64> + %17167 = stablehlo.dynamic_reshape %17165, %from_elements_6051 : (tensor, tensor<1xi64>) -> tensor + %dim_6052 = tensor.dim %17167, %c0 : tensor + %17168 = arith.index_cast %dim_6052 : index to i64 + %from_elements_6053 = tensor.from_elements %17168, %c1_i64 : tensor<2xi64> + %17169 = stablehlo.dynamic_reshape %17167, %from_elements_6053 : (tensor, tensor<2xi64>) -> tensor + %dim_6054 = tensor.dim %17169, %c0 : tensor + %17170 = arith.index_cast %dim_6054 : index to i64 + %from_elements_6055 = tensor.from_elements %c1_i64, %17170, %c4096_i64 : tensor<3xi64> + %17171 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_6055, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6056 = tensor.dim %17171, %c1 : tensor<1x?x4096xi64> + %17172 = arith.index_cast %dim_6056 : index to i64 + %from_elements_6057 = tensor.from_elements %c1_i64, %17172, %c4096_i64, %c1_i64 : tensor<4xi64> + %17173 = stablehlo.dynamic_reshape %17171, %from_elements_6057 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %17174 = stablehlo.dynamic_broadcast_in_dim %17169, %from_elements_6055, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6058 = tensor.dim %17174, %c1 : tensor<1x?x4096xi64> + %17175 = arith.index_cast %dim_6058 : index to i64 + %from_elements_6059 = tensor.from_elements %c1_i64, %17175, %c4096_i64, %c1_i64 : tensor<4xi64> + %17176 = stablehlo.dynamic_reshape %17174, %from_elements_6059 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %17177 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_6055, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6060 = tensor.dim %17177, %c1 : tensor<1x?x4096xi64> + %17178 = arith.index_cast %dim_6060 : index to i64 + %from_elements_6061 = tensor.from_elements %c1_i64, %17178, %c4096_i64, %c1_i64 : tensor<4xi64> + %17179 = stablehlo.dynamic_reshape %17177, %from_elements_6061 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %17180 = stablehlo.concatenate %17173, %17176, %17179, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %17181 = "stablehlo.gather"(%16725, %17180) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %17182 = shape.shape_of %17181 : tensor<1x?x4096xf32> -> tensor<3xindex> + %17183 = shape.num_elements %17182 : tensor<3xindex> -> index + %17184 = stablehlo.compute_reshape_shape %17183, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %17185 = stablehlo.dynamic_reshape %17181, %17184 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %17186 = stablehlo.dot %17185, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %17187 = 
stablehlo.logistic %17186 : tensor + %17188 = shape.shape_of %17187 : tensor -> tensor<2xindex> + %17189 = shape.shape_of %17186 : tensor -> tensor<2xindex> + %17190 = shape.cstr_broadcastable %17188, %17189 : tensor<2xindex>, tensor<2xindex> + %17191 = shape.assuming %17190 -> (tensor) { + %19688 = shape.broadcast %17188, %17189 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %17187, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %17186, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %17192 = shape.shape_of %17191 : tensor -> tensor<2xindex> + %17193 = shape.cstr_broadcastable %17192, %17189 : tensor<2xindex>, tensor<2xindex> + %17194 = shape.assuming %17193 -> (tensor) { + %19688 = shape.broadcast %17192, %17189 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %17191, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %17186, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %17195 = stablehlo.dot %17194, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_6062 = tensor.dim %17167, %c0 : tensor + %17196 = arith.index_cast %dim_6062 : index to i64 + %from_elements_6063 = tensor.from_elements %17196, %c1_i64 : tensor<2xi64> + %17197 = stablehlo.dynamic_reshape %17167, %from_elements_6063 : (tensor, tensor<2xi64>) -> tensor + %dim_6064 = tensor.dim %17164, %c0 : tensor + %17198 = arith.index_cast %dim_6064 : index to i64 + %from_elements_6065 = tensor.from_elements %17198, %c1_i64 : tensor<2xi64> + %17199 = stablehlo.dynamic_reshape %17164, %from_elements_6065 : (tensor, tensor<2xi64>) -> tensor + %17200 = stablehlo.concatenate %17197, %17199, dim = 1 : (tensor, tensor) -> tensor + %17201 = "stablehlo.gather"(%16754, %17200) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %17202 = shape.shape_of %17195 : tensor -> tensor<2xindex> + %17203 = shape.shape_of %17201 : tensor -> tensor<2xindex> + %17204 = shape.cstr_broadcastable %17202, %17203 : tensor<2xindex>, tensor<2xindex> + %17205 = shape.assuming %17204 -> (tensor) { + %19688 = shape.broadcast %17202, %17203 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %17195, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %17201, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %17206 = shape.shape_of %17205 : tensor -> tensor<2xindex> + %17207 = stablehlo.dynamic_broadcast_in_dim %17205, %17206, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %17208 = stablehlo.dynamic_broadcast_in_dim %213, %17206, dims = [] : (tensor, tensor<2xindex>) -> tensor + %17209 = stablehlo.multiply %17207, %17208 : tensor + %dim_6066 = tensor.dim %17169, %c0 : tensor + %17210 = arith.index_cast %dim_6066 : index to i64 + %dim_6067 = tensor.dim %17205, %c0 : tensor + %17211 = arith.index_cast %dim_6067 : index to i64 + %17212 = arith.maxsi %17210, %17211 : i64 + %17213 = arith.index_cast %17212 : i64 to index + %from_elements_6068 = 
tensor.from_elements %17213, %c4096 : tensor<2xindex> + %17214 = stablehlo.dynamic_broadcast_in_dim %17169, %from_elements_6068, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_6069 = tensor.dim %17214, %c0 : tensor + %17215 = arith.index_cast %dim_6069 : index to i64 + %from_elements_6070 = tensor.from_elements %17215, %c4096_i64 : tensor<2xi64> + %17216 = stablehlo.real_dynamic_slice %17209, %c_22, %from_elements_6070, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_6071 = tensor.from_elements %17215, %c4096_i64, %c1_i64 : tensor<3xi64> + %17217 = stablehlo.dynamic_reshape %17214, %from_elements_6071 : (tensor, tensor<3xi64>) -> tensor + %17218 = stablehlo.dynamic_iota %from_elements_6071, dim = 1 : (tensor<3xi64>) -> tensor + %17219 = stablehlo.concatenate %17217, %17218, dim = 2 : (tensor, tensor) -> tensor + %17220 = "stablehlo.scatter"(%17157, %17219, %17216) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %17221 = stablehlo.reshape %17220 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %17222 = stablehlo.add %16687, %17221 : tensor<3x1x4096xf32> + %17223 = stablehlo.broadcast_in_dim %17222, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %17224 = stablehlo.power %17223, %15 : tensor<3x1x4096xf32> + %17225 = stablehlo.reduce(%17224 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %17226 = stablehlo.reshape %17225 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %17227 = stablehlo.broadcast_in_dim %17226, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %17228 = stablehlo.divide %17227, %21 : tensor<3x1x1xf32> + %17229 = stablehlo.broadcast_in_dim %17228, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %17230 = stablehlo.add %17229, %25 : tensor<3x1x1xf32> + %17231 = stablehlo.rsqrt %17230 : tensor<3x1x1xf32> + %17232 = stablehlo.broadcast_in_dim %17231, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %17233 = stablehlo.multiply %17223, %17232 : tensor<3x1x4096xf32> + %17234 = stablehlo.broadcast_in_dim %17233, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %17235 = stablehlo.multiply %17234, %31 : tensor<3x1x4096xf32> + %17236 = stablehlo.reshape %17235 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %17237 = stablehlo.dot %17236, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %17238 = stablehlo.reshape %17237 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %17239 = stablehlo.dot %17236, %37 : (tensor<3x4096xf32>, tensor<4096x1024xf32>) -> tensor<3x1024xf32> + %17240 = stablehlo.reshape %17239 : (tensor<3x1024xf32>) -> tensor<3x1x1024xf32> + %17241 = stablehlo.reshape %17238 : (tensor<3x1x4096xf32>) -> tensor<3x1x32x128xf32> + %17242 = stablehlo.transpose %17241, dims = [0, 2, 1, 3] : (tensor<3x1x32x128xf32>) -> tensor<3x32x1x128xf32> + %17243 = stablehlo.reshape %17240 : (tensor<3x1x1024xf32>) -> tensor<3x1x8x128xf32> + %17244 = stablehlo.transpose %17243, dims = [0, 2, 1, 3] : (tensor<3x1x8x128xf32>) -> tensor<3x8x1x128xf32> + %17245 = stablehlo.slice %arg56 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %17246 = stablehlo.slice %arg57 [0:8, 0:128] : (tensor<131072x128xf32>) -> 
tensor<8x128xf32> + %17247 = "stablehlo.gather"(%17245, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %17248 = stablehlo.reshape %17247 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %17249 = "stablehlo.gather"(%17246, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %17250 = stablehlo.reshape %17249 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %17251 = stablehlo.broadcast_in_dim %17242, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %17252 = stablehlo.broadcast_in_dim %17248, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %17253 = stablehlo.multiply %17251, %17252 : tensor<3x32x1x128xf32> + %17254 = stablehlo.slice %17242 [0:3, 0:32, 0:1, 0:64] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %17255 = stablehlo.slice %17242 [0:3, 0:32, 0:1, 64:128] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %17256 = stablehlo.negate %17255 : tensor<3x32x1x64xf32> + %17257 = stablehlo.concatenate %17256, %17254, dim = 3 : (tensor<3x32x1x64xf32>, tensor<3x32x1x64xf32>) -> tensor<3x32x1x128xf32> + %17258 = stablehlo.broadcast_in_dim %17257, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %17259 = stablehlo.broadcast_in_dim %17250, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %17260 = stablehlo.multiply %17258, %17259 : tensor<3x32x1x128xf32> + %17261 = stablehlo.add %17253, %17260 : tensor<3x32x1x128xf32> + %17262 = stablehlo.broadcast_in_dim %17244, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %17263 = stablehlo.broadcast_in_dim %17248, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %17264 = stablehlo.multiply %17262, %17263 : tensor<3x8x1x128xf32> + %17265 = stablehlo.slice %17244 [0:3, 0:8, 0:1, 0:64] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %17266 = stablehlo.slice %17244 [0:3, 0:8, 0:1, 64:128] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %17267 = stablehlo.negate %17266 : tensor<3x8x1x64xf32> + %17268 = stablehlo.concatenate %17267, %17265, dim = 3 : (tensor<3x8x1x64xf32>, tensor<3x8x1x64xf32>) -> tensor<3x8x1x128xf32> + %17269 = stablehlo.broadcast_in_dim %17268, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %17270 = stablehlo.broadcast_in_dim %17250, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %17271 = stablehlo.multiply %17269, %17270 : tensor<3x8x1x128xf32> + %17272 = stablehlo.add %17264, %17271 : tensor<3x8x1x128xf32> + %17273 = stablehlo.concatenate %arg121, %17272, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %17274 = stablehlo.concatenate %arg122, %17244, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %17275 = stablehlo.reshape %17273 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %17276 = stablehlo.broadcast_in_dim %17275, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %17277 = stablehlo.reshape %17276 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %17278 = stablehlo.reshape %17274 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %17279 = stablehlo.broadcast_in_dim %17278, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %17280 = stablehlo.reshape 
%17279 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %17281 = stablehlo.transpose %17277, dims = [0, 1, 3, 2] : (tensor<3x32x8x128xf32>) -> tensor<3x32x128x8xf32> + %17282 = stablehlo.reshape %17261 : (tensor<3x32x1x128xf32>) -> tensor<96x1x128xf32> + %17283 = stablehlo.reshape %17281 : (tensor<3x32x128x8xf32>) -> tensor<96x128x8xf32> + %17284 = stablehlo.broadcast_in_dim %17283, dims = [0, 1, 2] : (tensor<96x128x8xf32>) -> tensor<96x128x8xf32> + %17285 = stablehlo.dot_general %17282, %17284, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x128xf32>, tensor<96x128x8xf32>) -> tensor<96x1x8xf32> + %17286 = stablehlo.reshape %17285 : (tensor<96x1x8xf32>) -> tensor<3x32x1x8xf32> + %17287 = stablehlo.broadcast_in_dim %17286, dims = [0, 1, 2, 3] : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %17288 = stablehlo.divide %17287, %89 : tensor<3x32x1x8xf32> + %17289 = stablehlo.custom_call @byteir.softmax(%17288) {byteir_attrs = {axis = 3 : i64}} : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %17290 = stablehlo.reshape %17289 : (tensor<3x32x1x8xf32>) -> tensor<96x1x8xf32> + %17291 = stablehlo.reshape %17280 : (tensor<3x32x8x128xf32>) -> tensor<96x8x128xf32> + %17292 = stablehlo.broadcast_in_dim %17291, dims = [0, 1, 2] : (tensor<96x8x128xf32>) -> tensor<96x8x128xf32> + %17293 = stablehlo.dot_general %17290, %17292, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x8xf32>, tensor<96x8x128xf32>) -> tensor<96x1x128xf32> + %17294 = stablehlo.reshape %17293 : (tensor<96x1x128xf32>) -> tensor<3x32x1x128xf32> + %17295 = stablehlo.transpose %17294, dims = [0, 2, 1, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x1x32x128xf32> + %17296 = stablehlo.reshape %17295 : (tensor<3x1x32x128xf32>) -> tensor<3x1x4096xf32> + %17297 = stablehlo.reshape %17296 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %17298 = stablehlo.dot %17297, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %17299 = stablehlo.reshape %17298 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %17300 = stablehlo.add %17222, %17299 : tensor<3x1x4096xf32> + %17301 = stablehlo.broadcast_in_dim %17300, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %17302 = stablehlo.power %17301, %15 : tensor<3x1x4096xf32> + %17303 = stablehlo.reduce(%17302 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %17304 = stablehlo.reshape %17303 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %17305 = stablehlo.broadcast_in_dim %17304, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %17306 = stablehlo.divide %17305, %21 : tensor<3x1x1xf32> + %17307 = stablehlo.broadcast_in_dim %17306, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %17308 = stablehlo.add %17307, %25 : tensor<3x1x1xf32> + %17309 = stablehlo.rsqrt %17308 : tensor<3x1x1xf32> + %17310 = stablehlo.broadcast_in_dim %17309, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %17311 = stablehlo.multiply %17301, %17310 : tensor<3x1x4096xf32> + %17312 = stablehlo.broadcast_in_dim %17311, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %17313 = stablehlo.multiply %17312, %31 : tensor<3x1x4096xf32> + %17314 = stablehlo.reshape %17313 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %17315 = stablehlo.dot %17314, %117 : (tensor<3x4096xf32>, tensor<4096x8xf32>) -> tensor<3x8xf32> + %17316 = stablehlo.custom_call @byteir.softmax(%17315) {byteir_attrs = {axis = 1 : i64}} : (tensor<3x8xf32>) -> 
tensor<3x8xf32> + %17317:2 = stablehlo.custom_call @byteir.top_k(%17316) {byteir_attrs = {axis = [1], k = 2 : i64, sorted = true}} : (tensor<3x8xf32>) -> (tensor<3x2xf32>, tensor<3x2xi64>) + %17318 = stablehlo.reduce(%17317#0 init: %cst_20) applies stablehlo.add across dimensions = [1] : (tensor<3x2xf32>, tensor) -> tensor<3xf32> + %17319 = stablehlo.reshape %17318 : (tensor<3xf32>) -> tensor<3x1xf32> + %17320 = stablehlo.broadcast_in_dim %17317#0, dims = [0, 1] : (tensor<3x2xf32>) -> tensor<3x2xf32> + %17321 = stablehlo.broadcast_in_dim %17319, dims = [0, 1] : (tensor<3x1xf32>) -> tensor<3x2xf32> + %17322 = stablehlo.divide %17320, %17321 : tensor<3x2xf32> + %17323 = stablehlo.reshape %17317#1 : (tensor<3x2xi64>) -> tensor<3x2x1xi64> + %17324 = stablehlo.broadcast_in_dim %17323, dims = [0, 1, 2] : (tensor<3x2x1xi64>) -> tensor<3x2x8xi64> + %17325 = stablehlo.compare EQ, %17324, %137, SIGNED : (tensor<3x2x8xi64>, tensor<3x2x8xi64>) -> tensor<3x2x8xi1> + %17326 = stablehlo.convert %17325 : (tensor<3x2x8xi1>) -> tensor<3x2x8xi64> + %17327 = stablehlo.transpose %17326, dims = [2, 1, 0] : (tensor<3x2x8xi64>) -> tensor<8x2x3xi64> + %17328 = stablehlo.slice %17327 [0:1, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %17329 = stablehlo.reshape %17328 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %17330 = stablehlo.custom_call @byteir.non_zero(%17329) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_6072 = tensor.dim %17330, %c0 : tensor + %17331 = arith.index_cast %dim_6072 : index to i64 + %from_elements_6073 = tensor.from_elements %17331, %c1_i64 : tensor<2xi64> + %17332 = stablehlo.real_dynamic_slice %17330, %c_22, %from_elements_6073, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6074 = tensor.dim %17332, %c0 : tensor + %17333 = arith.index_cast %dim_6074 : index to i64 + %from_elements_6075 = tensor.from_elements %17333 : tensor<1xi64> + %17334 = stablehlo.dynamic_reshape %17332, %from_elements_6075 : (tensor, tensor<1xi64>) -> tensor + %from_elements_6076 = tensor.from_elements %17331, %c2_i64 : tensor<2xi64> + %17335 = stablehlo.real_dynamic_slice %17330, %c_24, %from_elements_6076, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6077 = tensor.dim %17335, %c0 : tensor + %17336 = arith.index_cast %dim_6077 : index to i64 + %from_elements_6078 = tensor.from_elements %17336 : tensor<1xi64> + %17337 = stablehlo.dynamic_reshape %17335, %from_elements_6078 : (tensor, tensor<1xi64>) -> tensor + %17338 = stablehlo.reshape %17314 : (tensor<3x4096xf32>) -> tensor<1x3x4096xf32> + %dim_6079 = tensor.dim %17337, %c0 : tensor + %17339 = arith.index_cast %dim_6079 : index to i64 + %from_elements_6080 = tensor.from_elements %17339, %c1_i64 : tensor<2xi64> + %17340 = stablehlo.dynamic_reshape %17337, %from_elements_6080 : (tensor, tensor<2xi64>) -> tensor + %dim_6081 = tensor.dim %17340, %c0 : tensor + %17341 = arith.index_cast %dim_6081 : index to i64 + %from_elements_6082 = tensor.from_elements %c1_i64, %17341, %c4096_i64 : tensor<3xi64> + %17342 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_6082, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6083 = tensor.dim %17342, %c1 : tensor<1x?x4096xi64> + %17343 = arith.index_cast %dim_6083 : index to i64 + %from_elements_6084 = tensor.from_elements %c1_i64, %17343, %c4096_i64, %c1_i64 : tensor<4xi64> + %17344 = stablehlo.dynamic_reshape %17342, %from_elements_6084 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + 
%17345 = stablehlo.dynamic_broadcast_in_dim %17340, %from_elements_6082, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6085 = tensor.dim %17345, %c1 : tensor<1x?x4096xi64> + %17346 = arith.index_cast %dim_6085 : index to i64 + %from_elements_6086 = tensor.from_elements %c1_i64, %17346, %c4096_i64, %c1_i64 : tensor<4xi64> + %17347 = stablehlo.dynamic_reshape %17345, %from_elements_6086 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %17348 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_6082, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6087 = tensor.dim %17348, %c1 : tensor<1x?x4096xi64> + %17349 = arith.index_cast %dim_6087 : index to i64 + %from_elements_6088 = tensor.from_elements %c1_i64, %17349, %c4096_i64, %c1_i64 : tensor<4xi64> + %17350 = stablehlo.dynamic_reshape %17348, %from_elements_6088 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %17351 = stablehlo.concatenate %17344, %17347, %17350, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %17352 = "stablehlo.gather"(%17338, %17351) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %17353 = shape.shape_of %17352 : tensor<1x?x4096xf32> -> tensor<3xindex> + %17354 = shape.num_elements %17353 : tensor<3xindex> -> index + %17355 = stablehlo.compute_reshape_shape %17354, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %17356 = stablehlo.dynamic_reshape %17352, %17355 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %17357 = stablehlo.dot %17356, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %17358 = stablehlo.logistic %17357 : tensor + %17359 = shape.shape_of %17358 : tensor -> tensor<2xindex> + %17360 = shape.shape_of %17357 : tensor -> tensor<2xindex> + %17361 = shape.cstr_broadcastable %17359, %17360 : tensor<2xindex>, tensor<2xindex> + %17362 = shape.assuming %17361 -> (tensor) { + %19688 = shape.broadcast %17359, %17360 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %17358, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %17357, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %17363 = shape.shape_of %17362 : tensor -> tensor<2xindex> + %17364 = shape.cstr_broadcastable %17363, %17360 : tensor<2xindex>, tensor<2xindex> + %17365 = shape.assuming %17364 -> (tensor) { + %19688 = shape.broadcast %17363, %17360 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %17362, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %17357, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %17366 = stablehlo.dot %17365, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %17367 = stablehlo.reshape %17322 : (tensor<3x2xf32>) -> tensor<3x2x1xf32> + %dim_6089 = tensor.dim %17337, %c0 : tensor + %17368 = arith.index_cast %dim_6089 : index to i64 + %from_elements_6090 = tensor.from_elements %17368, %c1_i64 : tensor<2xi64> + %17369 = stablehlo.dynamic_reshape %17337, %from_elements_6090 : (tensor, tensor<2xi64>) -> tensor + %dim_6091 = 
tensor.dim %17334, %c0 : tensor + %17370 = arith.index_cast %dim_6091 : index to i64 + %from_elements_6092 = tensor.from_elements %17370, %c1_i64 : tensor<2xi64> + %17371 = stablehlo.dynamic_reshape %17334, %from_elements_6092 : (tensor, tensor<2xi64>) -> tensor + %17372 = stablehlo.concatenate %17369, %17371, dim = 1 : (tensor, tensor) -> tensor + %17373 = "stablehlo.gather"(%17367, %17372) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %17374 = shape.shape_of %17366 : tensor -> tensor<2xindex> + %17375 = shape.shape_of %17373 : tensor -> tensor<2xindex> + %17376 = shape.cstr_broadcastable %17374, %17375 : tensor<2xindex>, tensor<2xindex> + %17377 = shape.assuming %17376 -> (tensor) { + %19688 = shape.broadcast %17374, %17375 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %17366, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %17373, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %17378 = shape.shape_of %17377 : tensor -> tensor<2xindex> + %17379 = stablehlo.dynamic_broadcast_in_dim %17377, %17378, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %17380 = stablehlo.dynamic_broadcast_in_dim %213, %17378, dims = [] : (tensor, tensor<2xindex>) -> tensor + %17381 = stablehlo.multiply %17379, %17380 : tensor + %dim_6093 = tensor.dim %17340, %c0 : tensor + %17382 = arith.index_cast %dim_6093 : index to i64 + %dim_6094 = tensor.dim %17377, %c0 : tensor + %17383 = arith.index_cast %dim_6094 : index to i64 + %17384 = arith.maxsi %17382, %17383 : i64 + %17385 = arith.index_cast %17384 : i64 to index + %from_elements_6095 = tensor.from_elements %17385, %c4096 : tensor<2xindex> + %17386 = stablehlo.dynamic_broadcast_in_dim %17340, %from_elements_6095, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_6096 = tensor.dim %17386, %c0 : tensor + %17387 = arith.index_cast %dim_6096 : index to i64 + %from_elements_6097 = tensor.from_elements %17387, %c4096_i64 : tensor<2xi64> + %17388 = stablehlo.real_dynamic_slice %17381, %c_22, %from_elements_6097, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_6098 = tensor.from_elements %17387, %c4096_i64, %c1_i64 : tensor<3xi64> + %17389 = stablehlo.dynamic_reshape %17386, %from_elements_6098 : (tensor, tensor<3xi64>) -> tensor + %17390 = stablehlo.dynamic_iota %from_elements_6098, dim = 1 : (tensor<3xi64>) -> tensor + %17391 = stablehlo.concatenate %17389, %17390, dim = 2 : (tensor, tensor) -> tensor + %17392 = "stablehlo.scatter"(%cst_2, %17391, %17388) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %17393 = stablehlo.slice %17327 [1:2, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %17394 = stablehlo.reshape %17393 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %17395 = stablehlo.custom_call @byteir.non_zero(%17394) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_6099 = tensor.dim %17395, %c0 : tensor + %17396 = arith.index_cast %dim_6099 : index to i64 + %from_elements_6100 = tensor.from_elements %17396, %c1_i64 : tensor<2xi64> + %17397 = 
stablehlo.real_dynamic_slice %17395, %c_22, %from_elements_6100, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6101 = tensor.dim %17397, %c0 : tensor + %17398 = arith.index_cast %dim_6101 : index to i64 + %from_elements_6102 = tensor.from_elements %17398 : tensor<1xi64> + %17399 = stablehlo.dynamic_reshape %17397, %from_elements_6102 : (tensor, tensor<1xi64>) -> tensor + %from_elements_6103 = tensor.from_elements %17396, %c2_i64 : tensor<2xi64> + %17400 = stablehlo.real_dynamic_slice %17395, %c_24, %from_elements_6103, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6104 = tensor.dim %17400, %c0 : tensor + %17401 = arith.index_cast %dim_6104 : index to i64 + %from_elements_6105 = tensor.from_elements %17401 : tensor<1xi64> + %17402 = stablehlo.dynamic_reshape %17400, %from_elements_6105 : (tensor, tensor<1xi64>) -> tensor + %dim_6106 = tensor.dim %17402, %c0 : tensor + %17403 = arith.index_cast %dim_6106 : index to i64 + %from_elements_6107 = tensor.from_elements %17403, %c1_i64 : tensor<2xi64> + %17404 = stablehlo.dynamic_reshape %17402, %from_elements_6107 : (tensor, tensor<2xi64>) -> tensor + %dim_6108 = tensor.dim %17404, %c0 : tensor + %17405 = arith.index_cast %dim_6108 : index to i64 + %from_elements_6109 = tensor.from_elements %c1_i64, %17405, %c4096_i64 : tensor<3xi64> + %17406 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_6109, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6110 = tensor.dim %17406, %c1 : tensor<1x?x4096xi64> + %17407 = arith.index_cast %dim_6110 : index to i64 + %from_elements_6111 = tensor.from_elements %c1_i64, %17407, %c4096_i64, %c1_i64 : tensor<4xi64> + %17408 = stablehlo.dynamic_reshape %17406, %from_elements_6111 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %17409 = stablehlo.dynamic_broadcast_in_dim %17404, %from_elements_6109, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6112 = tensor.dim %17409, %c1 : tensor<1x?x4096xi64> + %17410 = arith.index_cast %dim_6112 : index to i64 + %from_elements_6113 = tensor.from_elements %c1_i64, %17410, %c4096_i64, %c1_i64 : tensor<4xi64> + %17411 = stablehlo.dynamic_reshape %17409, %from_elements_6113 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %17412 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_6109, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6114 = tensor.dim %17412, %c1 : tensor<1x?x4096xi64> + %17413 = arith.index_cast %dim_6114 : index to i64 + %from_elements_6115 = tensor.from_elements %c1_i64, %17413, %c4096_i64, %c1_i64 : tensor<4xi64> + %17414 = stablehlo.dynamic_reshape %17412, %from_elements_6115 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %17415 = stablehlo.concatenate %17408, %17411, %17414, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %17416 = "stablehlo.gather"(%17338, %17415) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %17417 = shape.shape_of %17416 : tensor<1x?x4096xf32> -> tensor<3xindex> + %17418 = shape.num_elements %17417 : tensor<3xindex> -> index + %17419 = stablehlo.compute_reshape_shape %17418, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %17420 = stablehlo.dynamic_reshape %17416, %17419 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %17421 = 
stablehlo.dot %17420, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %17422 = stablehlo.logistic %17421 : tensor + %17423 = shape.shape_of %17422 : tensor -> tensor<2xindex> + %17424 = shape.shape_of %17421 : tensor -> tensor<2xindex> + %17425 = shape.cstr_broadcastable %17423, %17424 : tensor<2xindex>, tensor<2xindex> + %17426 = shape.assuming %17425 -> (tensor) { + %19688 = shape.broadcast %17423, %17424 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %17422, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %17421, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %17427 = shape.shape_of %17426 : tensor -> tensor<2xindex> + %17428 = shape.cstr_broadcastable %17427, %17424 : tensor<2xindex>, tensor<2xindex> + %17429 = shape.assuming %17428 -> (tensor) { + %19688 = shape.broadcast %17427, %17424 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %17426, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %17421, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %17430 = stablehlo.dot %17429, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_6116 = tensor.dim %17402, %c0 : tensor + %17431 = arith.index_cast %dim_6116 : index to i64 + %from_elements_6117 = tensor.from_elements %17431, %c1_i64 : tensor<2xi64> + %17432 = stablehlo.dynamic_reshape %17402, %from_elements_6117 : (tensor, tensor<2xi64>) -> tensor + %dim_6118 = tensor.dim %17399, %c0 : tensor + %17433 = arith.index_cast %dim_6118 : index to i64 + %from_elements_6119 = tensor.from_elements %17433, %c1_i64 : tensor<2xi64> + %17434 = stablehlo.dynamic_reshape %17399, %from_elements_6119 : (tensor, tensor<2xi64>) -> tensor + %17435 = stablehlo.concatenate %17432, %17434, dim = 1 : (tensor, tensor) -> tensor + %17436 = "stablehlo.gather"(%17367, %17435) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %17437 = shape.shape_of %17430 : tensor -> tensor<2xindex> + %17438 = shape.shape_of %17436 : tensor -> tensor<2xindex> + %17439 = shape.cstr_broadcastable %17437, %17438 : tensor<2xindex>, tensor<2xindex> + %17440 = shape.assuming %17439 -> (tensor) { + %19688 = shape.broadcast %17437, %17438 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %17430, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %17436, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %17441 = shape.shape_of %17440 : tensor -> tensor<2xindex> + %17442 = stablehlo.dynamic_broadcast_in_dim %17440, %17441, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %17443 = stablehlo.dynamic_broadcast_in_dim %213, %17441, dims = [] : (tensor, tensor<2xindex>) -> tensor + %17444 = stablehlo.multiply %17442, %17443 : tensor + %dim_6120 = tensor.dim %17404, %c0 : tensor + %17445 = arith.index_cast %dim_6120 : index to i64 + %dim_6121 = tensor.dim %17440, %c0 : tensor + %17446 = arith.index_cast %dim_6121 : index to i64 + %17447 = arith.maxsi %17445, 
%17446 : i64 + %17448 = arith.index_cast %17447 : i64 to index + %from_elements_6122 = tensor.from_elements %17448, %c4096 : tensor<2xindex> + %17449 = stablehlo.dynamic_broadcast_in_dim %17404, %from_elements_6122, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_6123 = tensor.dim %17449, %c0 : tensor + %17450 = arith.index_cast %dim_6123 : index to i64 + %from_elements_6124 = tensor.from_elements %17450, %c4096_i64 : tensor<2xi64> + %17451 = stablehlo.real_dynamic_slice %17444, %c_22, %from_elements_6124, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_6125 = tensor.from_elements %17450, %c4096_i64, %c1_i64 : tensor<3xi64> + %17452 = stablehlo.dynamic_reshape %17449, %from_elements_6125 : (tensor, tensor<3xi64>) -> tensor + %17453 = stablehlo.dynamic_iota %from_elements_6125, dim = 1 : (tensor<3xi64>) -> tensor + %17454 = stablehlo.concatenate %17452, %17453, dim = 2 : (tensor, tensor) -> tensor + %17455 = "stablehlo.scatter"(%17392, %17454, %17451) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %17456 = stablehlo.slice %17327 [2:3, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %17457 = stablehlo.reshape %17456 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %17458 = stablehlo.custom_call @byteir.non_zero(%17457) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_6126 = tensor.dim %17458, %c0 : tensor + %17459 = arith.index_cast %dim_6126 : index to i64 + %from_elements_6127 = tensor.from_elements %17459, %c1_i64 : tensor<2xi64> + %17460 = stablehlo.real_dynamic_slice %17458, %c_22, %from_elements_6127, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6128 = tensor.dim %17460, %c0 : tensor + %17461 = arith.index_cast %dim_6128 : index to i64 + %from_elements_6129 = tensor.from_elements %17461 : tensor<1xi64> + %17462 = stablehlo.dynamic_reshape %17460, %from_elements_6129 : (tensor, tensor<1xi64>) -> tensor + %from_elements_6130 = tensor.from_elements %17459, %c2_i64 : tensor<2xi64> + %17463 = stablehlo.real_dynamic_slice %17458, %c_24, %from_elements_6130, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6131 = tensor.dim %17463, %c0 : tensor + %17464 = arith.index_cast %dim_6131 : index to i64 + %from_elements_6132 = tensor.from_elements %17464 : tensor<1xi64> + %17465 = stablehlo.dynamic_reshape %17463, %from_elements_6132 : (tensor, tensor<1xi64>) -> tensor + %dim_6133 = tensor.dim %17465, %c0 : tensor + %17466 = arith.index_cast %dim_6133 : index to i64 + %from_elements_6134 = tensor.from_elements %17466, %c1_i64 : tensor<2xi64> + %17467 = stablehlo.dynamic_reshape %17465, %from_elements_6134 : (tensor, tensor<2xi64>) -> tensor + %dim_6135 = tensor.dim %17467, %c0 : tensor + %17468 = arith.index_cast %dim_6135 : index to i64 + %from_elements_6136 = tensor.from_elements %c1_i64, %17468, %c4096_i64 : tensor<3xi64> + %17469 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_6136, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6137 = tensor.dim %17469, %c1 : tensor<1x?x4096xi64> + %17470 = arith.index_cast %dim_6137 : index to i64 + %from_elements_6138 = tensor.from_elements %c1_i64, %17470, %c4096_i64, %c1_i64 : tensor<4xi64> + %17471 = stablehlo.dynamic_reshape %17469, 
%from_elements_6138 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %17472 = stablehlo.dynamic_broadcast_in_dim %17467, %from_elements_6136, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6139 = tensor.dim %17472, %c1 : tensor<1x?x4096xi64> + %17473 = arith.index_cast %dim_6139 : index to i64 + %from_elements_6140 = tensor.from_elements %c1_i64, %17473, %c4096_i64, %c1_i64 : tensor<4xi64> + %17474 = stablehlo.dynamic_reshape %17472, %from_elements_6140 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %17475 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_6136, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6141 = tensor.dim %17475, %c1 : tensor<1x?x4096xi64> + %17476 = arith.index_cast %dim_6141 : index to i64 + %from_elements_6142 = tensor.from_elements %c1_i64, %17476, %c4096_i64, %c1_i64 : tensor<4xi64> + %17477 = stablehlo.dynamic_reshape %17475, %from_elements_6142 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %17478 = stablehlo.concatenate %17471, %17474, %17477, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %17479 = "stablehlo.gather"(%17338, %17478) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %17480 = shape.shape_of %17479 : tensor<1x?x4096xf32> -> tensor<3xindex> + %17481 = shape.num_elements %17480 : tensor<3xindex> -> index + %17482 = stablehlo.compute_reshape_shape %17481, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %17483 = stablehlo.dynamic_reshape %17479, %17482 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %17484 = stablehlo.dot %17483, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %17485 = stablehlo.logistic %17484 : tensor + %17486 = shape.shape_of %17485 : tensor -> tensor<2xindex> + %17487 = shape.shape_of %17484 : tensor -> tensor<2xindex> + %17488 = shape.cstr_broadcastable %17486, %17487 : tensor<2xindex>, tensor<2xindex> + %17489 = shape.assuming %17488 -> (tensor) { + %19688 = shape.broadcast %17486, %17487 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %17485, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %17484, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %17490 = shape.shape_of %17489 : tensor -> tensor<2xindex> + %17491 = shape.cstr_broadcastable %17490, %17487 : tensor<2xindex>, tensor<2xindex> + %17492 = shape.assuming %17491 -> (tensor) { + %19688 = shape.broadcast %17490, %17487 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %17489, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %17484, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %17493 = stablehlo.dot %17492, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_6143 = tensor.dim %17465, %c0 : tensor + %17494 = arith.index_cast %dim_6143 : index to i64 + %from_elements_6144 = tensor.from_elements %17494, %c1_i64 : tensor<2xi64> + %17495 = stablehlo.dynamic_reshape %17465, %from_elements_6144 : (tensor, tensor<2xi64>) -> tensor + 
%dim_6145 = tensor.dim %17462, %c0 : tensor + %17496 = arith.index_cast %dim_6145 : index to i64 + %from_elements_6146 = tensor.from_elements %17496, %c1_i64 : tensor<2xi64> + %17497 = stablehlo.dynamic_reshape %17462, %from_elements_6146 : (tensor, tensor<2xi64>) -> tensor + %17498 = stablehlo.concatenate %17495, %17497, dim = 1 : (tensor, tensor) -> tensor + %17499 = "stablehlo.gather"(%17367, %17498) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %17500 = shape.shape_of %17493 : tensor -> tensor<2xindex> + %17501 = shape.shape_of %17499 : tensor -> tensor<2xindex> + %17502 = shape.cstr_broadcastable %17500, %17501 : tensor<2xindex>, tensor<2xindex> + %17503 = shape.assuming %17502 -> (tensor) { + %19688 = shape.broadcast %17500, %17501 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %17493, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %17499, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %17504 = shape.shape_of %17503 : tensor -> tensor<2xindex> + %17505 = stablehlo.dynamic_broadcast_in_dim %17503, %17504, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %17506 = stablehlo.dynamic_broadcast_in_dim %213, %17504, dims = [] : (tensor, tensor<2xindex>) -> tensor + %17507 = stablehlo.multiply %17505, %17506 : tensor + %dim_6147 = tensor.dim %17467, %c0 : tensor + %17508 = arith.index_cast %dim_6147 : index to i64 + %dim_6148 = tensor.dim %17503, %c0 : tensor + %17509 = arith.index_cast %dim_6148 : index to i64 + %17510 = arith.maxsi %17508, %17509 : i64 + %17511 = arith.index_cast %17510 : i64 to index + %from_elements_6149 = tensor.from_elements %17511, %c4096 : tensor<2xindex> + %17512 = stablehlo.dynamic_broadcast_in_dim %17467, %from_elements_6149, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_6150 = tensor.dim %17512, %c0 : tensor + %17513 = arith.index_cast %dim_6150 : index to i64 + %from_elements_6151 = tensor.from_elements %17513, %c4096_i64 : tensor<2xi64> + %17514 = stablehlo.real_dynamic_slice %17507, %c_22, %from_elements_6151, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_6152 = tensor.from_elements %17513, %c4096_i64, %c1_i64 : tensor<3xi64> + %17515 = stablehlo.dynamic_reshape %17512, %from_elements_6152 : (tensor, tensor<3xi64>) -> tensor + %17516 = stablehlo.dynamic_iota %from_elements_6152, dim = 1 : (tensor<3xi64>) -> tensor + %17517 = stablehlo.concatenate %17515, %17516, dim = 2 : (tensor, tensor) -> tensor + %17518 = "stablehlo.scatter"(%17455, %17517, %17514) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %17519 = stablehlo.slice %17327 [3:4, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %17520 = stablehlo.reshape %17519 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %17521 = stablehlo.custom_call @byteir.non_zero(%17520) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_6153 = tensor.dim %17521, %c0 : tensor + %17522 = arith.index_cast %dim_6153 : index to i64 + %from_elements_6154 = tensor.from_elements %17522, %c1_i64 : tensor<2xi64> + 
%17523 = stablehlo.real_dynamic_slice %17521, %c_22, %from_elements_6154, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6155 = tensor.dim %17523, %c0 : tensor + %17524 = arith.index_cast %dim_6155 : index to i64 + %from_elements_6156 = tensor.from_elements %17524 : tensor<1xi64> + %17525 = stablehlo.dynamic_reshape %17523, %from_elements_6156 : (tensor, tensor<1xi64>) -> tensor + %from_elements_6157 = tensor.from_elements %17522, %c2_i64 : tensor<2xi64> + %17526 = stablehlo.real_dynamic_slice %17521, %c_24, %from_elements_6157, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6158 = tensor.dim %17526, %c0 : tensor + %17527 = arith.index_cast %dim_6158 : index to i64 + %from_elements_6159 = tensor.from_elements %17527 : tensor<1xi64> + %17528 = stablehlo.dynamic_reshape %17526, %from_elements_6159 : (tensor, tensor<1xi64>) -> tensor + %dim_6160 = tensor.dim %17528, %c0 : tensor + %17529 = arith.index_cast %dim_6160 : index to i64 + %from_elements_6161 = tensor.from_elements %17529, %c1_i64 : tensor<2xi64> + %17530 = stablehlo.dynamic_reshape %17528, %from_elements_6161 : (tensor, tensor<2xi64>) -> tensor + %dim_6162 = tensor.dim %17530, %c0 : tensor + %17531 = arith.index_cast %dim_6162 : index to i64 + %from_elements_6163 = tensor.from_elements %c1_i64, %17531, %c4096_i64 : tensor<3xi64> + %17532 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_6163, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6164 = tensor.dim %17532, %c1 : tensor<1x?x4096xi64> + %17533 = arith.index_cast %dim_6164 : index to i64 + %from_elements_6165 = tensor.from_elements %c1_i64, %17533, %c4096_i64, %c1_i64 : tensor<4xi64> + %17534 = stablehlo.dynamic_reshape %17532, %from_elements_6165 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %17535 = stablehlo.dynamic_broadcast_in_dim %17530, %from_elements_6163, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6166 = tensor.dim %17535, %c1 : tensor<1x?x4096xi64> + %17536 = arith.index_cast %dim_6166 : index to i64 + %from_elements_6167 = tensor.from_elements %c1_i64, %17536, %c4096_i64, %c1_i64 : tensor<4xi64> + %17537 = stablehlo.dynamic_reshape %17535, %from_elements_6167 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %17538 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_6163, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6168 = tensor.dim %17538, %c1 : tensor<1x?x4096xi64> + %17539 = arith.index_cast %dim_6168 : index to i64 + %from_elements_6169 = tensor.from_elements %c1_i64, %17539, %c4096_i64, %c1_i64 : tensor<4xi64> + %17540 = stablehlo.dynamic_reshape %17538, %from_elements_6169 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %17541 = stablehlo.concatenate %17534, %17537, %17540, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %17542 = "stablehlo.gather"(%17338, %17541) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %17543 = shape.shape_of %17542 : tensor<1x?x4096xf32> -> tensor<3xindex> + %17544 = shape.num_elements %17543 : tensor<3xindex> -> index + %17545 = stablehlo.compute_reshape_shape %17544, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %17546 = stablehlo.dynamic_reshape %17542, %17545 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + 
%17547 = stablehlo.dot %17546, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %17548 = stablehlo.logistic %17547 : tensor + %17549 = shape.shape_of %17548 : tensor -> tensor<2xindex> + %17550 = shape.shape_of %17547 : tensor -> tensor<2xindex> + %17551 = shape.cstr_broadcastable %17549, %17550 : tensor<2xindex>, tensor<2xindex> + %17552 = shape.assuming %17551 -> (tensor) { + %19688 = shape.broadcast %17549, %17550 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %17548, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %17547, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %17553 = shape.shape_of %17552 : tensor -> tensor<2xindex> + %17554 = shape.cstr_broadcastable %17553, %17550 : tensor<2xindex>, tensor<2xindex> + %17555 = shape.assuming %17554 -> (tensor) { + %19688 = shape.broadcast %17553, %17550 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %17552, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %17547, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %17556 = stablehlo.dot %17555, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_6170 = tensor.dim %17528, %c0 : tensor + %17557 = arith.index_cast %dim_6170 : index to i64 + %from_elements_6171 = tensor.from_elements %17557, %c1_i64 : tensor<2xi64> + %17558 = stablehlo.dynamic_reshape %17528, %from_elements_6171 : (tensor, tensor<2xi64>) -> tensor + %dim_6172 = tensor.dim %17525, %c0 : tensor + %17559 = arith.index_cast %dim_6172 : index to i64 + %from_elements_6173 = tensor.from_elements %17559, %c1_i64 : tensor<2xi64> + %17560 = stablehlo.dynamic_reshape %17525, %from_elements_6173 : (tensor, tensor<2xi64>) -> tensor + %17561 = stablehlo.concatenate %17558, %17560, dim = 1 : (tensor, tensor) -> tensor + %17562 = "stablehlo.gather"(%17367, %17561) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %17563 = shape.shape_of %17556 : tensor -> tensor<2xindex> + %17564 = shape.shape_of %17562 : tensor -> tensor<2xindex> + %17565 = shape.cstr_broadcastable %17563, %17564 : tensor<2xindex>, tensor<2xindex> + %17566 = shape.assuming %17565 -> (tensor) { + %19688 = shape.broadcast %17563, %17564 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %17556, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %17562, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %17567 = shape.shape_of %17566 : tensor -> tensor<2xindex> + %17568 = stablehlo.dynamic_broadcast_in_dim %17566, %17567, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %17569 = stablehlo.dynamic_broadcast_in_dim %213, %17567, dims = [] : (tensor, tensor<2xindex>) -> tensor + %17570 = stablehlo.multiply %17568, %17569 : tensor + %dim_6174 = tensor.dim %17530, %c0 : tensor + %17571 = arith.index_cast %dim_6174 : index to i64 + %dim_6175 = tensor.dim %17566, %c0 : tensor + %17572 = arith.index_cast %dim_6175 : index to i64 + %17573 = arith.maxsi 
%17571, %17572 : i64 + %17574 = arith.index_cast %17573 : i64 to index + %from_elements_6176 = tensor.from_elements %17574, %c4096 : tensor<2xindex> + %17575 = stablehlo.dynamic_broadcast_in_dim %17530, %from_elements_6176, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_6177 = tensor.dim %17575, %c0 : tensor + %17576 = arith.index_cast %dim_6177 : index to i64 + %from_elements_6178 = tensor.from_elements %17576, %c4096_i64 : tensor<2xi64> + %17577 = stablehlo.real_dynamic_slice %17570, %c_22, %from_elements_6178, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_6179 = tensor.from_elements %17576, %c4096_i64, %c1_i64 : tensor<3xi64> + %17578 = stablehlo.dynamic_reshape %17575, %from_elements_6179 : (tensor, tensor<3xi64>) -> tensor + %17579 = stablehlo.dynamic_iota %from_elements_6179, dim = 1 : (tensor<3xi64>) -> tensor + %17580 = stablehlo.concatenate %17578, %17579, dim = 2 : (tensor, tensor) -> tensor + %17581 = "stablehlo.scatter"(%17518, %17580, %17577) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %17582 = stablehlo.slice %17327 [4:5, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %17583 = stablehlo.reshape %17582 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %17584 = stablehlo.custom_call @byteir.non_zero(%17583) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_6180 = tensor.dim %17584, %c0 : tensor + %17585 = arith.index_cast %dim_6180 : index to i64 + %from_elements_6181 = tensor.from_elements %17585, %c1_i64 : tensor<2xi64> + %17586 = stablehlo.real_dynamic_slice %17584, %c_22, %from_elements_6181, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6182 = tensor.dim %17586, %c0 : tensor + %17587 = arith.index_cast %dim_6182 : index to i64 + %from_elements_6183 = tensor.from_elements %17587 : tensor<1xi64> + %17588 = stablehlo.dynamic_reshape %17586, %from_elements_6183 : (tensor, tensor<1xi64>) -> tensor + %from_elements_6184 = tensor.from_elements %17585, %c2_i64 : tensor<2xi64> + %17589 = stablehlo.real_dynamic_slice %17584, %c_24, %from_elements_6184, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6185 = tensor.dim %17589, %c0 : tensor + %17590 = arith.index_cast %dim_6185 : index to i64 + %from_elements_6186 = tensor.from_elements %17590 : tensor<1xi64> + %17591 = stablehlo.dynamic_reshape %17589, %from_elements_6186 : (tensor, tensor<1xi64>) -> tensor + %dim_6187 = tensor.dim %17591, %c0 : tensor + %17592 = arith.index_cast %dim_6187 : index to i64 + %from_elements_6188 = tensor.from_elements %17592, %c1_i64 : tensor<2xi64> + %17593 = stablehlo.dynamic_reshape %17591, %from_elements_6188 : (tensor, tensor<2xi64>) -> tensor + %dim_6189 = tensor.dim %17593, %c0 : tensor + %17594 = arith.index_cast %dim_6189 : index to i64 + %from_elements_6190 = tensor.from_elements %c1_i64, %17594, %c4096_i64 : tensor<3xi64> + %17595 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_6190, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6191 = tensor.dim %17595, %c1 : tensor<1x?x4096xi64> + %17596 = arith.index_cast %dim_6191 : index to i64 + %from_elements_6192 = tensor.from_elements %c1_i64, %17596, %c4096_i64, %c1_i64 : tensor<4xi64> + %17597 = stablehlo.dynamic_reshape 
%17595, %from_elements_6192 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %17598 = stablehlo.dynamic_broadcast_in_dim %17593, %from_elements_6190, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6193 = tensor.dim %17598, %c1 : tensor<1x?x4096xi64> + %17599 = arith.index_cast %dim_6193 : index to i64 + %from_elements_6194 = tensor.from_elements %c1_i64, %17599, %c4096_i64, %c1_i64 : tensor<4xi64> + %17600 = stablehlo.dynamic_reshape %17598, %from_elements_6194 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %17601 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_6190, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6195 = tensor.dim %17601, %c1 : tensor<1x?x4096xi64> + %17602 = arith.index_cast %dim_6195 : index to i64 + %from_elements_6196 = tensor.from_elements %c1_i64, %17602, %c4096_i64, %c1_i64 : tensor<4xi64> + %17603 = stablehlo.dynamic_reshape %17601, %from_elements_6196 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %17604 = stablehlo.concatenate %17597, %17600, %17603, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %17605 = "stablehlo.gather"(%17338, %17604) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %17606 = shape.shape_of %17605 : tensor<1x?x4096xf32> -> tensor<3xindex> + %17607 = shape.num_elements %17606 : tensor<3xindex> -> index + %17608 = stablehlo.compute_reshape_shape %17607, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %17609 = stablehlo.dynamic_reshape %17605, %17608 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %17610 = stablehlo.dot %17609, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %17611 = stablehlo.logistic %17610 : tensor + %17612 = shape.shape_of %17611 : tensor -> tensor<2xindex> + %17613 = shape.shape_of %17610 : tensor -> tensor<2xindex> + %17614 = shape.cstr_broadcastable %17612, %17613 : tensor<2xindex>, tensor<2xindex> + %17615 = shape.assuming %17614 -> (tensor) { + %19688 = shape.broadcast %17612, %17613 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %17611, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %17610, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %17616 = shape.shape_of %17615 : tensor -> tensor<2xindex> + %17617 = shape.cstr_broadcastable %17616, %17613 : tensor<2xindex>, tensor<2xindex> + %17618 = shape.assuming %17617 -> (tensor) { + %19688 = shape.broadcast %17616, %17613 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %17615, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %17610, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %17619 = stablehlo.dot %17618, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_6197 = tensor.dim %17591, %c0 : tensor + %17620 = arith.index_cast %dim_6197 : index to i64 + %from_elements_6198 = tensor.from_elements %17620, %c1_i64 : tensor<2xi64> + %17621 = stablehlo.dynamic_reshape %17591, %from_elements_6198 : (tensor, tensor<2xi64>) -> 
tensor + %dim_6199 = tensor.dim %17588, %c0 : tensor + %17622 = arith.index_cast %dim_6199 : index to i64 + %from_elements_6200 = tensor.from_elements %17622, %c1_i64 : tensor<2xi64> + %17623 = stablehlo.dynamic_reshape %17588, %from_elements_6200 : (tensor, tensor<2xi64>) -> tensor + %17624 = stablehlo.concatenate %17621, %17623, dim = 1 : (tensor, tensor) -> tensor + %17625 = "stablehlo.gather"(%17367, %17624) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %17626 = shape.shape_of %17619 : tensor -> tensor<2xindex> + %17627 = shape.shape_of %17625 : tensor -> tensor<2xindex> + %17628 = shape.cstr_broadcastable %17626, %17627 : tensor<2xindex>, tensor<2xindex> + %17629 = shape.assuming %17628 -> (tensor) { + %19688 = shape.broadcast %17626, %17627 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %17619, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %17625, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %17630 = shape.shape_of %17629 : tensor -> tensor<2xindex> + %17631 = stablehlo.dynamic_broadcast_in_dim %17629, %17630, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %17632 = stablehlo.dynamic_broadcast_in_dim %213, %17630, dims = [] : (tensor, tensor<2xindex>) -> tensor + %17633 = stablehlo.multiply %17631, %17632 : tensor + %dim_6201 = tensor.dim %17593, %c0 : tensor + %17634 = arith.index_cast %dim_6201 : index to i64 + %dim_6202 = tensor.dim %17629, %c0 : tensor + %17635 = arith.index_cast %dim_6202 : index to i64 + %17636 = arith.maxsi %17634, %17635 : i64 + %17637 = arith.index_cast %17636 : i64 to index + %from_elements_6203 = tensor.from_elements %17637, %c4096 : tensor<2xindex> + %17638 = stablehlo.dynamic_broadcast_in_dim %17593, %from_elements_6203, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_6204 = tensor.dim %17638, %c0 : tensor + %17639 = arith.index_cast %dim_6204 : index to i64 + %from_elements_6205 = tensor.from_elements %17639, %c4096_i64 : tensor<2xi64> + %17640 = stablehlo.real_dynamic_slice %17633, %c_22, %from_elements_6205, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_6206 = tensor.from_elements %17639, %c4096_i64, %c1_i64 : tensor<3xi64> + %17641 = stablehlo.dynamic_reshape %17638, %from_elements_6206 : (tensor, tensor<3xi64>) -> tensor + %17642 = stablehlo.dynamic_iota %from_elements_6206, dim = 1 : (tensor<3xi64>) -> tensor + %17643 = stablehlo.concatenate %17641, %17642, dim = 2 : (tensor, tensor) -> tensor + %17644 = "stablehlo.scatter"(%17581, %17643, %17640) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %17645 = stablehlo.slice %17327 [5:6, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %17646 = stablehlo.reshape %17645 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %17647 = stablehlo.custom_call @byteir.non_zero(%17646) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_6207 = tensor.dim %17647, %c0 : tensor + %17648 = arith.index_cast %dim_6207 : index to i64 + %from_elements_6208 = tensor.from_elements %17648, %c1_i64 : 
tensor<2xi64> + %17649 = stablehlo.real_dynamic_slice %17647, %c_22, %from_elements_6208, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6209 = tensor.dim %17649, %c0 : tensor + %17650 = arith.index_cast %dim_6209 : index to i64 + %from_elements_6210 = tensor.from_elements %17650 : tensor<1xi64> + %17651 = stablehlo.dynamic_reshape %17649, %from_elements_6210 : (tensor, tensor<1xi64>) -> tensor + %from_elements_6211 = tensor.from_elements %17648, %c2_i64 : tensor<2xi64> + %17652 = stablehlo.real_dynamic_slice %17647, %c_24, %from_elements_6211, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6212 = tensor.dim %17652, %c0 : tensor + %17653 = arith.index_cast %dim_6212 : index to i64 + %from_elements_6213 = tensor.from_elements %17653 : tensor<1xi64> + %17654 = stablehlo.dynamic_reshape %17652, %from_elements_6213 : (tensor, tensor<1xi64>) -> tensor + %dim_6214 = tensor.dim %17654, %c0 : tensor + %17655 = arith.index_cast %dim_6214 : index to i64 + %from_elements_6215 = tensor.from_elements %17655, %c1_i64 : tensor<2xi64> + %17656 = stablehlo.dynamic_reshape %17654, %from_elements_6215 : (tensor, tensor<2xi64>) -> tensor + %dim_6216 = tensor.dim %17656, %c0 : tensor + %17657 = arith.index_cast %dim_6216 : index to i64 + %from_elements_6217 = tensor.from_elements %c1_i64, %17657, %c4096_i64 : tensor<3xi64> + %17658 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_6217, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6218 = tensor.dim %17658, %c1 : tensor<1x?x4096xi64> + %17659 = arith.index_cast %dim_6218 : index to i64 + %from_elements_6219 = tensor.from_elements %c1_i64, %17659, %c4096_i64, %c1_i64 : tensor<4xi64> + %17660 = stablehlo.dynamic_reshape %17658, %from_elements_6219 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %17661 = stablehlo.dynamic_broadcast_in_dim %17656, %from_elements_6217, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6220 = tensor.dim %17661, %c1 : tensor<1x?x4096xi64> + %17662 = arith.index_cast %dim_6220 : index to i64 + %from_elements_6221 = tensor.from_elements %c1_i64, %17662, %c4096_i64, %c1_i64 : tensor<4xi64> + %17663 = stablehlo.dynamic_reshape %17661, %from_elements_6221 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %17664 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_6217, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6222 = tensor.dim %17664, %c1 : tensor<1x?x4096xi64> + %17665 = arith.index_cast %dim_6222 : index to i64 + %from_elements_6223 = tensor.from_elements %c1_i64, %17665, %c4096_i64, %c1_i64 : tensor<4xi64> + %17666 = stablehlo.dynamic_reshape %17664, %from_elements_6223 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %17667 = stablehlo.concatenate %17660, %17663, %17666, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %17668 = "stablehlo.gather"(%17338, %17667) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %17669 = shape.shape_of %17668 : tensor<1x?x4096xf32> -> tensor<3xindex> + %17670 = shape.num_elements %17669 : tensor<3xindex> -> index + %17671 = stablehlo.compute_reshape_shape %17670, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %17672 = stablehlo.dynamic_reshape %17668, %17671 : (tensor<1x?x4096xf32>, tensor<2xi64>) 
-> tensor + %17673 = stablehlo.dot %17672, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %17674 = stablehlo.logistic %17673 : tensor + %17675 = shape.shape_of %17674 : tensor -> tensor<2xindex> + %17676 = shape.shape_of %17673 : tensor -> tensor<2xindex> + %17677 = shape.cstr_broadcastable %17675, %17676 : tensor<2xindex>, tensor<2xindex> + %17678 = shape.assuming %17677 -> (tensor) { + %19688 = shape.broadcast %17675, %17676 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %17674, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %17673, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %17679 = shape.shape_of %17678 : tensor -> tensor<2xindex> + %17680 = shape.cstr_broadcastable %17679, %17676 : tensor<2xindex>, tensor<2xindex> + %17681 = shape.assuming %17680 -> (tensor) { + %19688 = shape.broadcast %17679, %17676 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %17678, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %17673, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %17682 = stablehlo.dot %17681, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_6224 = tensor.dim %17654, %c0 : tensor + %17683 = arith.index_cast %dim_6224 : index to i64 + %from_elements_6225 = tensor.from_elements %17683, %c1_i64 : tensor<2xi64> + %17684 = stablehlo.dynamic_reshape %17654, %from_elements_6225 : (tensor, tensor<2xi64>) -> tensor + %dim_6226 = tensor.dim %17651, %c0 : tensor + %17685 = arith.index_cast %dim_6226 : index to i64 + %from_elements_6227 = tensor.from_elements %17685, %c1_i64 : tensor<2xi64> + %17686 = stablehlo.dynamic_reshape %17651, %from_elements_6227 : (tensor, tensor<2xi64>) -> tensor + %17687 = stablehlo.concatenate %17684, %17686, dim = 1 : (tensor, tensor) -> tensor + %17688 = "stablehlo.gather"(%17367, %17687) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %17689 = shape.shape_of %17682 : tensor -> tensor<2xindex> + %17690 = shape.shape_of %17688 : tensor -> tensor<2xindex> + %17691 = shape.cstr_broadcastable %17689, %17690 : tensor<2xindex>, tensor<2xindex> + %17692 = shape.assuming %17691 -> (tensor) { + %19688 = shape.broadcast %17689, %17690 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %17682, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %17688, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %17693 = shape.shape_of %17692 : tensor -> tensor<2xindex> + %17694 = stablehlo.dynamic_broadcast_in_dim %17692, %17693, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %17695 = stablehlo.dynamic_broadcast_in_dim %213, %17693, dims = [] : (tensor, tensor<2xindex>) -> tensor + %17696 = stablehlo.multiply %17694, %17695 : tensor + %dim_6228 = tensor.dim %17656, %c0 : tensor + %17697 = arith.index_cast %dim_6228 : index to i64 + %dim_6229 = tensor.dim %17692, %c0 : tensor + %17698 = arith.index_cast %dim_6229 : index to i64 + %17699 = 
arith.maxsi %17697, %17698 : i64 + %17700 = arith.index_cast %17699 : i64 to index + %from_elements_6230 = tensor.from_elements %17700, %c4096 : tensor<2xindex> + %17701 = stablehlo.dynamic_broadcast_in_dim %17656, %from_elements_6230, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_6231 = tensor.dim %17701, %c0 : tensor + %17702 = arith.index_cast %dim_6231 : index to i64 + %from_elements_6232 = tensor.from_elements %17702, %c4096_i64 : tensor<2xi64> + %17703 = stablehlo.real_dynamic_slice %17696, %c_22, %from_elements_6232, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_6233 = tensor.from_elements %17702, %c4096_i64, %c1_i64 : tensor<3xi64> + %17704 = stablehlo.dynamic_reshape %17701, %from_elements_6233 : (tensor, tensor<3xi64>) -> tensor + %17705 = stablehlo.dynamic_iota %from_elements_6233, dim = 1 : (tensor<3xi64>) -> tensor + %17706 = stablehlo.concatenate %17704, %17705, dim = 2 : (tensor, tensor) -> tensor + %17707 = "stablehlo.scatter"(%17644, %17706, %17703) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %17708 = stablehlo.slice %17327 [6:7, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %17709 = stablehlo.reshape %17708 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %17710 = stablehlo.custom_call @byteir.non_zero(%17709) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_6234 = tensor.dim %17710, %c0 : tensor + %17711 = arith.index_cast %dim_6234 : index to i64 + %from_elements_6235 = tensor.from_elements %17711, %c1_i64 : tensor<2xi64> + %17712 = stablehlo.real_dynamic_slice %17710, %c_22, %from_elements_6235, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6236 = tensor.dim %17712, %c0 : tensor + %17713 = arith.index_cast %dim_6236 : index to i64 + %from_elements_6237 = tensor.from_elements %17713 : tensor<1xi64> + %17714 = stablehlo.dynamic_reshape %17712, %from_elements_6237 : (tensor, tensor<1xi64>) -> tensor + %from_elements_6238 = tensor.from_elements %17711, %c2_i64 : tensor<2xi64> + %17715 = stablehlo.real_dynamic_slice %17710, %c_24, %from_elements_6238, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6239 = tensor.dim %17715, %c0 : tensor + %17716 = arith.index_cast %dim_6239 : index to i64 + %from_elements_6240 = tensor.from_elements %17716 : tensor<1xi64> + %17717 = stablehlo.dynamic_reshape %17715, %from_elements_6240 : (tensor, tensor<1xi64>) -> tensor + %dim_6241 = tensor.dim %17717, %c0 : tensor + %17718 = arith.index_cast %dim_6241 : index to i64 + %from_elements_6242 = tensor.from_elements %17718, %c1_i64 : tensor<2xi64> + %17719 = stablehlo.dynamic_reshape %17717, %from_elements_6242 : (tensor, tensor<2xi64>) -> tensor + %dim_6243 = tensor.dim %17719, %c0 : tensor + %17720 = arith.index_cast %dim_6243 : index to i64 + %from_elements_6244 = tensor.from_elements %c1_i64, %17720, %c4096_i64 : tensor<3xi64> + %17721 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_6244, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6245 = tensor.dim %17721, %c1 : tensor<1x?x4096xi64> + %17722 = arith.index_cast %dim_6245 : index to i64 + %from_elements_6246 = tensor.from_elements %c1_i64, %17722, %c4096_i64, %c1_i64 : tensor<4xi64> + %17723 = 
stablehlo.dynamic_reshape %17721, %from_elements_6246 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %17724 = stablehlo.dynamic_broadcast_in_dim %17719, %from_elements_6244, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6247 = tensor.dim %17724, %c1 : tensor<1x?x4096xi64> + %17725 = arith.index_cast %dim_6247 : index to i64 + %from_elements_6248 = tensor.from_elements %c1_i64, %17725, %c4096_i64, %c1_i64 : tensor<4xi64> + %17726 = stablehlo.dynamic_reshape %17724, %from_elements_6248 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %17727 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_6244, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6249 = tensor.dim %17727, %c1 : tensor<1x?x4096xi64> + %17728 = arith.index_cast %dim_6249 : index to i64 + %from_elements_6250 = tensor.from_elements %c1_i64, %17728, %c4096_i64, %c1_i64 : tensor<4xi64> + %17729 = stablehlo.dynamic_reshape %17727, %from_elements_6250 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %17730 = stablehlo.concatenate %17723, %17726, %17729, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %17731 = "stablehlo.gather"(%17338, %17730) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %17732 = shape.shape_of %17731 : tensor<1x?x4096xf32> -> tensor<3xindex> + %17733 = shape.num_elements %17732 : tensor<3xindex> -> index + %17734 = stablehlo.compute_reshape_shape %17733, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %17735 = stablehlo.dynamic_reshape %17731, %17734 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %17736 = stablehlo.dot %17735, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %17737 = stablehlo.logistic %17736 : tensor + %17738 = shape.shape_of %17737 : tensor -> tensor<2xindex> + %17739 = shape.shape_of %17736 : tensor -> tensor<2xindex> + %17740 = shape.cstr_broadcastable %17738, %17739 : tensor<2xindex>, tensor<2xindex> + %17741 = shape.assuming %17740 -> (tensor) { + %19688 = shape.broadcast %17738, %17739 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %17737, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %17736, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %17742 = shape.shape_of %17741 : tensor -> tensor<2xindex> + %17743 = shape.cstr_broadcastable %17742, %17739 : tensor<2xindex>, tensor<2xindex> + %17744 = shape.assuming %17743 -> (tensor) { + %19688 = shape.broadcast %17742, %17739 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %17741, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %17736, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %17745 = stablehlo.dot %17744, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_6251 = tensor.dim %17717, %c0 : tensor + %17746 = arith.index_cast %dim_6251 : index to i64 + %from_elements_6252 = tensor.from_elements %17746, %c1_i64 : tensor<2xi64> + %17747 = stablehlo.dynamic_reshape %17717, %from_elements_6252 : 
(tensor, tensor<2xi64>) -> tensor + %dim_6253 = tensor.dim %17714, %c0 : tensor + %17748 = arith.index_cast %dim_6253 : index to i64 + %from_elements_6254 = tensor.from_elements %17748, %c1_i64 : tensor<2xi64> + %17749 = stablehlo.dynamic_reshape %17714, %from_elements_6254 : (tensor, tensor<2xi64>) -> tensor + %17750 = stablehlo.concatenate %17747, %17749, dim = 1 : (tensor, tensor) -> tensor + %17751 = "stablehlo.gather"(%17367, %17750) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %17752 = shape.shape_of %17745 : tensor -> tensor<2xindex> + %17753 = shape.shape_of %17751 : tensor -> tensor<2xindex> + %17754 = shape.cstr_broadcastable %17752, %17753 : tensor<2xindex>, tensor<2xindex> + %17755 = shape.assuming %17754 -> (tensor) { + %19688 = shape.broadcast %17752, %17753 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %17745, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %17751, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %17756 = shape.shape_of %17755 : tensor -> tensor<2xindex> + %17757 = stablehlo.dynamic_broadcast_in_dim %17755, %17756, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %17758 = stablehlo.dynamic_broadcast_in_dim %213, %17756, dims = [] : (tensor, tensor<2xindex>) -> tensor + %17759 = stablehlo.multiply %17757, %17758 : tensor + %dim_6255 = tensor.dim %17719, %c0 : tensor + %17760 = arith.index_cast %dim_6255 : index to i64 + %dim_6256 = tensor.dim %17755, %c0 : tensor + %17761 = arith.index_cast %dim_6256 : index to i64 + %17762 = arith.maxsi %17760, %17761 : i64 + %17763 = arith.index_cast %17762 : i64 to index + %from_elements_6257 = tensor.from_elements %17763, %c4096 : tensor<2xindex> + %17764 = stablehlo.dynamic_broadcast_in_dim %17719, %from_elements_6257, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_6258 = tensor.dim %17764, %c0 : tensor + %17765 = arith.index_cast %dim_6258 : index to i64 + %from_elements_6259 = tensor.from_elements %17765, %c4096_i64 : tensor<2xi64> + %17766 = stablehlo.real_dynamic_slice %17759, %c_22, %from_elements_6259, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_6260 = tensor.from_elements %17765, %c4096_i64, %c1_i64 : tensor<3xi64> + %17767 = stablehlo.dynamic_reshape %17764, %from_elements_6260 : (tensor, tensor<3xi64>) -> tensor + %17768 = stablehlo.dynamic_iota %from_elements_6260, dim = 1 : (tensor<3xi64>) -> tensor + %17769 = stablehlo.concatenate %17767, %17768, dim = 2 : (tensor, tensor) -> tensor + %17770 = "stablehlo.scatter"(%17707, %17769, %17766) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %17771 = stablehlo.slice %17327 [7:8, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %17772 = stablehlo.reshape %17771 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %17773 = stablehlo.custom_call @byteir.non_zero(%17772) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_6261 = tensor.dim %17773, %c0 : tensor + %17774 = arith.index_cast %dim_6261 : index to i64 + %from_elements_6262 = tensor.from_elements 
%17774, %c1_i64 : tensor<2xi64> + %17775 = stablehlo.real_dynamic_slice %17773, %c_22, %from_elements_6262, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6263 = tensor.dim %17775, %c0 : tensor + %17776 = arith.index_cast %dim_6263 : index to i64 + %from_elements_6264 = tensor.from_elements %17776 : tensor<1xi64> + %17777 = stablehlo.dynamic_reshape %17775, %from_elements_6264 : (tensor, tensor<1xi64>) -> tensor + %from_elements_6265 = tensor.from_elements %17774, %c2_i64 : tensor<2xi64> + %17778 = stablehlo.real_dynamic_slice %17773, %c_24, %from_elements_6265, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6266 = tensor.dim %17778, %c0 : tensor + %17779 = arith.index_cast %dim_6266 : index to i64 + %from_elements_6267 = tensor.from_elements %17779 : tensor<1xi64> + %17780 = stablehlo.dynamic_reshape %17778, %from_elements_6267 : (tensor, tensor<1xi64>) -> tensor + %dim_6268 = tensor.dim %17780, %c0 : tensor + %17781 = arith.index_cast %dim_6268 : index to i64 + %from_elements_6269 = tensor.from_elements %17781, %c1_i64 : tensor<2xi64> + %17782 = stablehlo.dynamic_reshape %17780, %from_elements_6269 : (tensor, tensor<2xi64>) -> tensor + %dim_6270 = tensor.dim %17782, %c0 : tensor + %17783 = arith.index_cast %dim_6270 : index to i64 + %from_elements_6271 = tensor.from_elements %c1_i64, %17783, %c4096_i64 : tensor<3xi64> + %17784 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_6271, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6272 = tensor.dim %17784, %c1 : tensor<1x?x4096xi64> + %17785 = arith.index_cast %dim_6272 : index to i64 + %from_elements_6273 = tensor.from_elements %c1_i64, %17785, %c4096_i64, %c1_i64 : tensor<4xi64> + %17786 = stablehlo.dynamic_reshape %17784, %from_elements_6273 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %17787 = stablehlo.dynamic_broadcast_in_dim %17782, %from_elements_6271, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6274 = tensor.dim %17787, %c1 : tensor<1x?x4096xi64> + %17788 = arith.index_cast %dim_6274 : index to i64 + %from_elements_6275 = tensor.from_elements %c1_i64, %17788, %c4096_i64, %c1_i64 : tensor<4xi64> + %17789 = stablehlo.dynamic_reshape %17787, %from_elements_6275 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %17790 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_6271, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6276 = tensor.dim %17790, %c1 : tensor<1x?x4096xi64> + %17791 = arith.index_cast %dim_6276 : index to i64 + %from_elements_6277 = tensor.from_elements %c1_i64, %17791, %c4096_i64, %c1_i64 : tensor<4xi64> + %17792 = stablehlo.dynamic_reshape %17790, %from_elements_6277 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %17793 = stablehlo.concatenate %17786, %17789, %17792, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %17794 = "stablehlo.gather"(%17338, %17793) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %17795 = shape.shape_of %17794 : tensor<1x?x4096xf32> -> tensor<3xindex> + %17796 = shape.num_elements %17795 : tensor<3xindex> -> index + %17797 = stablehlo.compute_reshape_shape %17796, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %17798 = stablehlo.dynamic_reshape %17794, %17797 : 
(tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %17799 = stablehlo.dot %17798, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %17800 = stablehlo.logistic %17799 : tensor + %17801 = shape.shape_of %17800 : tensor -> tensor<2xindex> + %17802 = shape.shape_of %17799 : tensor -> tensor<2xindex> + %17803 = shape.cstr_broadcastable %17801, %17802 : tensor<2xindex>, tensor<2xindex> + %17804 = shape.assuming %17803 -> (tensor) { + %19688 = shape.broadcast %17801, %17802 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %17800, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %17799, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %17805 = shape.shape_of %17804 : tensor -> tensor<2xindex> + %17806 = shape.cstr_broadcastable %17805, %17802 : tensor<2xindex>, tensor<2xindex> + %17807 = shape.assuming %17806 -> (tensor) { + %19688 = shape.broadcast %17805, %17802 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %17804, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %17799, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %17808 = stablehlo.dot %17807, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_6278 = tensor.dim %17780, %c0 : tensor + %17809 = arith.index_cast %dim_6278 : index to i64 + %from_elements_6279 = tensor.from_elements %17809, %c1_i64 : tensor<2xi64> + %17810 = stablehlo.dynamic_reshape %17780, %from_elements_6279 : (tensor, tensor<2xi64>) -> tensor + %dim_6280 = tensor.dim %17777, %c0 : tensor + %17811 = arith.index_cast %dim_6280 : index to i64 + %from_elements_6281 = tensor.from_elements %17811, %c1_i64 : tensor<2xi64> + %17812 = stablehlo.dynamic_reshape %17777, %from_elements_6281 : (tensor, tensor<2xi64>) -> tensor + %17813 = stablehlo.concatenate %17810, %17812, dim = 1 : (tensor, tensor) -> tensor + %17814 = "stablehlo.gather"(%17367, %17813) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %17815 = shape.shape_of %17808 : tensor -> tensor<2xindex> + %17816 = shape.shape_of %17814 : tensor -> tensor<2xindex> + %17817 = shape.cstr_broadcastable %17815, %17816 : tensor<2xindex>, tensor<2xindex> + %17818 = shape.assuming %17817 -> (tensor) { + %19688 = shape.broadcast %17815, %17816 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %17808, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %17814, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %17819 = shape.shape_of %17818 : tensor -> tensor<2xindex> + %17820 = stablehlo.dynamic_broadcast_in_dim %17818, %17819, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %17821 = stablehlo.dynamic_broadcast_in_dim %213, %17819, dims = [] : (tensor, tensor<2xindex>) -> tensor + %17822 = stablehlo.multiply %17820, %17821 : tensor + %dim_6282 = tensor.dim %17782, %c0 : tensor + %17823 = arith.index_cast %dim_6282 : index to i64 + %dim_6283 = tensor.dim %17818, %c0 : tensor + %17824 = arith.index_cast 
%dim_6283 : index to i64 + %17825 = arith.maxsi %17823, %17824 : i64 + %17826 = arith.index_cast %17825 : i64 to index + %from_elements_6284 = tensor.from_elements %17826, %c4096 : tensor<2xindex> + %17827 = stablehlo.dynamic_broadcast_in_dim %17782, %from_elements_6284, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_6285 = tensor.dim %17827, %c0 : tensor + %17828 = arith.index_cast %dim_6285 : index to i64 + %from_elements_6286 = tensor.from_elements %17828, %c4096_i64 : tensor<2xi64> + %17829 = stablehlo.real_dynamic_slice %17822, %c_22, %from_elements_6286, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_6287 = tensor.from_elements %17828, %c4096_i64, %c1_i64 : tensor<3xi64> + %17830 = stablehlo.dynamic_reshape %17827, %from_elements_6287 : (tensor, tensor<3xi64>) -> tensor + %17831 = stablehlo.dynamic_iota %from_elements_6287, dim = 1 : (tensor<3xi64>) -> tensor + %17832 = stablehlo.concatenate %17830, %17831, dim = 2 : (tensor, tensor) -> tensor + %17833 = "stablehlo.scatter"(%17770, %17832, %17829) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %17834 = stablehlo.reshape %17833 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %17835 = stablehlo.add %17300, %17834 : tensor<3x1x4096xf32> + %17836 = stablehlo.broadcast_in_dim %17835, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %17837 = stablehlo.power %17836, %15 : tensor<3x1x4096xf32> + %17838 = stablehlo.reduce(%17837 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %17839 = stablehlo.reshape %17838 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %17840 = stablehlo.broadcast_in_dim %17839, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %17841 = stablehlo.divide %17840, %21 : tensor<3x1x1xf32> + %17842 = stablehlo.broadcast_in_dim %17841, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %17843 = stablehlo.add %17842, %25 : tensor<3x1x1xf32> + %17844 = stablehlo.rsqrt %17843 : tensor<3x1x1xf32> + %17845 = stablehlo.broadcast_in_dim %17844, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %17846 = stablehlo.multiply %17836, %17845 : tensor<3x1x4096xf32> + %17847 = stablehlo.broadcast_in_dim %17846, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %17848 = stablehlo.multiply %17847, %31 : tensor<3x1x4096xf32> + %17849 = stablehlo.reshape %17848 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %17850 = stablehlo.dot %17849, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %17851 = stablehlo.reshape %17850 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %17852 = stablehlo.dot %17849, %37 : (tensor<3x4096xf32>, tensor<4096x1024xf32>) -> tensor<3x1024xf32> + %17853 = stablehlo.reshape %17852 : (tensor<3x1024xf32>) -> tensor<3x1x1024xf32> + %17854 = stablehlo.reshape %17851 : (tensor<3x1x4096xf32>) -> tensor<3x1x32x128xf32> + %17855 = stablehlo.transpose %17854, dims = [0, 2, 1, 3] : (tensor<3x1x32x128xf32>) -> tensor<3x32x1x128xf32> + %17856 = stablehlo.reshape %17853 : (tensor<3x1x1024xf32>) -> tensor<3x1x8x128xf32> + %17857 = stablehlo.transpose %17856, dims = [0, 2, 1, 3] : (tensor<3x1x8x128xf32>) -> tensor<3x8x1x128xf32> + %17858 = stablehlo.slice %arg58 [0:8, 
0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %17859 = stablehlo.slice %arg59 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %17860 = "stablehlo.gather"(%17858, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %17861 = stablehlo.reshape %17860 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %17862 = "stablehlo.gather"(%17859, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %17863 = stablehlo.reshape %17862 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %17864 = stablehlo.broadcast_in_dim %17855, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %17865 = stablehlo.broadcast_in_dim %17861, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %17866 = stablehlo.multiply %17864, %17865 : tensor<3x32x1x128xf32> + %17867 = stablehlo.slice %17855 [0:3, 0:32, 0:1, 0:64] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %17868 = stablehlo.slice %17855 [0:3, 0:32, 0:1, 64:128] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %17869 = stablehlo.negate %17868 : tensor<3x32x1x64xf32> + %17870 = stablehlo.concatenate %17869, %17867, dim = 3 : (tensor<3x32x1x64xf32>, tensor<3x32x1x64xf32>) -> tensor<3x32x1x128xf32> + %17871 = stablehlo.broadcast_in_dim %17870, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %17872 = stablehlo.broadcast_in_dim %17863, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %17873 = stablehlo.multiply %17871, %17872 : tensor<3x32x1x128xf32> + %17874 = stablehlo.add %17866, %17873 : tensor<3x32x1x128xf32> + %17875 = stablehlo.broadcast_in_dim %17857, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %17876 = stablehlo.broadcast_in_dim %17861, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %17877 = stablehlo.multiply %17875, %17876 : tensor<3x8x1x128xf32> + %17878 = stablehlo.slice %17857 [0:3, 0:8, 0:1, 0:64] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %17879 = stablehlo.slice %17857 [0:3, 0:8, 0:1, 64:128] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %17880 = stablehlo.negate %17879 : tensor<3x8x1x64xf32> + %17881 = stablehlo.concatenate %17880, %17878, dim = 3 : (tensor<3x8x1x64xf32>, tensor<3x8x1x64xf32>) -> tensor<3x8x1x128xf32> + %17882 = stablehlo.broadcast_in_dim %17881, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %17883 = stablehlo.broadcast_in_dim %17863, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %17884 = stablehlo.multiply %17882, %17883 : tensor<3x8x1x128xf32> + %17885 = stablehlo.add %17877, %17884 : tensor<3x8x1x128xf32> + %17886 = stablehlo.concatenate %arg123, %17885, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %17887 = stablehlo.concatenate %arg124, %17857, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %17888 = stablehlo.reshape %17886 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %17889 = stablehlo.broadcast_in_dim %17888, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %17890 = stablehlo.reshape %17889 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %17891 = stablehlo.reshape %17887 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %17892 = 
stablehlo.broadcast_in_dim %17891, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %17893 = stablehlo.reshape %17892 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %17894 = stablehlo.transpose %17890, dims = [0, 1, 3, 2] : (tensor<3x32x8x128xf32>) -> tensor<3x32x128x8xf32> + %17895 = stablehlo.reshape %17874 : (tensor<3x32x1x128xf32>) -> tensor<96x1x128xf32> + %17896 = stablehlo.reshape %17894 : (tensor<3x32x128x8xf32>) -> tensor<96x128x8xf32> + %17897 = stablehlo.broadcast_in_dim %17896, dims = [0, 1, 2] : (tensor<96x128x8xf32>) -> tensor<96x128x8xf32> + %17898 = stablehlo.dot_general %17895, %17897, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x128xf32>, tensor<96x128x8xf32>) -> tensor<96x1x8xf32> + %17899 = stablehlo.reshape %17898 : (tensor<96x1x8xf32>) -> tensor<3x32x1x8xf32> + %17900 = stablehlo.broadcast_in_dim %17899, dims = [0, 1, 2, 3] : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %17901 = stablehlo.divide %17900, %89 : tensor<3x32x1x8xf32> + %17902 = stablehlo.custom_call @byteir.softmax(%17901) {byteir_attrs = {axis = 3 : i64}} : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %17903 = stablehlo.reshape %17902 : (tensor<3x32x1x8xf32>) -> tensor<96x1x8xf32> + %17904 = stablehlo.reshape %17893 : (tensor<3x32x8x128xf32>) -> tensor<96x8x128xf32> + %17905 = stablehlo.broadcast_in_dim %17904, dims = [0, 1, 2] : (tensor<96x8x128xf32>) -> tensor<96x8x128xf32> + %17906 = stablehlo.dot_general %17903, %17905, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x8xf32>, tensor<96x8x128xf32>) -> tensor<96x1x128xf32> + %17907 = stablehlo.reshape %17906 : (tensor<96x1x128xf32>) -> tensor<3x32x1x128xf32> + %17908 = stablehlo.transpose %17907, dims = [0, 2, 1, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x1x32x128xf32> + %17909 = stablehlo.reshape %17908 : (tensor<3x1x32x128xf32>) -> tensor<3x1x4096xf32> + %17910 = stablehlo.reshape %17909 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %17911 = stablehlo.dot %17910, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %17912 = stablehlo.reshape %17911 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %17913 = stablehlo.add %17835, %17912 : tensor<3x1x4096xf32> + %17914 = stablehlo.broadcast_in_dim %17913, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %17915 = stablehlo.power %17914, %15 : tensor<3x1x4096xf32> + %17916 = stablehlo.reduce(%17915 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %17917 = stablehlo.reshape %17916 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %17918 = stablehlo.broadcast_in_dim %17917, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %17919 = stablehlo.divide %17918, %21 : tensor<3x1x1xf32> + %17920 = stablehlo.broadcast_in_dim %17919, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %17921 = stablehlo.add %17920, %25 : tensor<3x1x1xf32> + %17922 = stablehlo.rsqrt %17921 : tensor<3x1x1xf32> + %17923 = stablehlo.broadcast_in_dim %17922, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %17924 = stablehlo.multiply %17914, %17923 : tensor<3x1x4096xf32> + %17925 = stablehlo.broadcast_in_dim %17924, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %17926 = stablehlo.multiply %17925, %31 : tensor<3x1x4096xf32> + %17927 = stablehlo.reshape %17926 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %17928 = stablehlo.dot %17927, %117 : (tensor<3x4096xf32>, tensor<4096x8xf32>) -> 
tensor<3x8xf32> + %17929 = stablehlo.custom_call @byteir.softmax(%17928) {byteir_attrs = {axis = 1 : i64}} : (tensor<3x8xf32>) -> tensor<3x8xf32> + %17930:2 = stablehlo.custom_call @byteir.top_k(%17929) {byteir_attrs = {axis = [1], k = 2 : i64, sorted = true}} : (tensor<3x8xf32>) -> (tensor<3x2xf32>, tensor<3x2xi64>) + %17931 = stablehlo.reduce(%17930#0 init: %cst_20) applies stablehlo.add across dimensions = [1] : (tensor<3x2xf32>, tensor) -> tensor<3xf32> + %17932 = stablehlo.reshape %17931 : (tensor<3xf32>) -> tensor<3x1xf32> + %17933 = stablehlo.broadcast_in_dim %17930#0, dims = [0, 1] : (tensor<3x2xf32>) -> tensor<3x2xf32> + %17934 = stablehlo.broadcast_in_dim %17932, dims = [0, 1] : (tensor<3x1xf32>) -> tensor<3x2xf32> + %17935 = stablehlo.divide %17933, %17934 : tensor<3x2xf32> + %17936 = stablehlo.reshape %17930#1 : (tensor<3x2xi64>) -> tensor<3x2x1xi64> + %17937 = stablehlo.broadcast_in_dim %17936, dims = [0, 1, 2] : (tensor<3x2x1xi64>) -> tensor<3x2x8xi64> + %17938 = stablehlo.compare EQ, %17937, %137, SIGNED : (tensor<3x2x8xi64>, tensor<3x2x8xi64>) -> tensor<3x2x8xi1> + %17939 = stablehlo.convert %17938 : (tensor<3x2x8xi1>) -> tensor<3x2x8xi64> + %17940 = stablehlo.transpose %17939, dims = [2, 1, 0] : (tensor<3x2x8xi64>) -> tensor<8x2x3xi64> + %17941 = stablehlo.slice %17940 [0:1, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %17942 = stablehlo.reshape %17941 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %17943 = stablehlo.custom_call @byteir.non_zero(%17942) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_6288 = tensor.dim %17943, %c0 : tensor + %17944 = arith.index_cast %dim_6288 : index to i64 + %from_elements_6289 = tensor.from_elements %17944, %c1_i64 : tensor<2xi64> + %17945 = stablehlo.real_dynamic_slice %17943, %c_22, %from_elements_6289, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6290 = tensor.dim %17945, %c0 : tensor + %17946 = arith.index_cast %dim_6290 : index to i64 + %from_elements_6291 = tensor.from_elements %17946 : tensor<1xi64> + %17947 = stablehlo.dynamic_reshape %17945, %from_elements_6291 : (tensor, tensor<1xi64>) -> tensor + %from_elements_6292 = tensor.from_elements %17944, %c2_i64 : tensor<2xi64> + %17948 = stablehlo.real_dynamic_slice %17943, %c_24, %from_elements_6292, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6293 = tensor.dim %17948, %c0 : tensor + %17949 = arith.index_cast %dim_6293 : index to i64 + %from_elements_6294 = tensor.from_elements %17949 : tensor<1xi64> + %17950 = stablehlo.dynamic_reshape %17948, %from_elements_6294 : (tensor, tensor<1xi64>) -> tensor + %17951 = stablehlo.reshape %17927 : (tensor<3x4096xf32>) -> tensor<1x3x4096xf32> + %dim_6295 = tensor.dim %17950, %c0 : tensor + %17952 = arith.index_cast %dim_6295 : index to i64 + %from_elements_6296 = tensor.from_elements %17952, %c1_i64 : tensor<2xi64> + %17953 = stablehlo.dynamic_reshape %17950, %from_elements_6296 : (tensor, tensor<2xi64>) -> tensor + %dim_6297 = tensor.dim %17953, %c0 : tensor + %17954 = arith.index_cast %dim_6297 : index to i64 + %from_elements_6298 = tensor.from_elements %c1_i64, %17954, %c4096_i64 : tensor<3xi64> + %17955 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_6298, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6299 = tensor.dim %17955, %c1 : tensor<1x?x4096xi64> + %17956 = arith.index_cast %dim_6299 : index to i64 + %from_elements_6300 = tensor.from_elements %c1_i64, %17956, %c4096_i64, %c1_i64 : tensor<4xi64> + %17957 
= stablehlo.dynamic_reshape %17955, %from_elements_6300 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %17958 = stablehlo.dynamic_broadcast_in_dim %17953, %from_elements_6298, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6301 = tensor.dim %17958, %c1 : tensor<1x?x4096xi64> + %17959 = arith.index_cast %dim_6301 : index to i64 + %from_elements_6302 = tensor.from_elements %c1_i64, %17959, %c4096_i64, %c1_i64 : tensor<4xi64> + %17960 = stablehlo.dynamic_reshape %17958, %from_elements_6302 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %17961 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_6298, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6303 = tensor.dim %17961, %c1 : tensor<1x?x4096xi64> + %17962 = arith.index_cast %dim_6303 : index to i64 + %from_elements_6304 = tensor.from_elements %c1_i64, %17962, %c4096_i64, %c1_i64 : tensor<4xi64> + %17963 = stablehlo.dynamic_reshape %17961, %from_elements_6304 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %17964 = stablehlo.concatenate %17957, %17960, %17963, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %17965 = "stablehlo.gather"(%17951, %17964) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %17966 = shape.shape_of %17965 : tensor<1x?x4096xf32> -> tensor<3xindex> + %17967 = shape.num_elements %17966 : tensor<3xindex> -> index + %17968 = stablehlo.compute_reshape_shape %17967, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %17969 = stablehlo.dynamic_reshape %17965, %17968 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %17970 = stablehlo.dot %17969, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %17971 = stablehlo.logistic %17970 : tensor + %17972 = shape.shape_of %17971 : tensor -> tensor<2xindex> + %17973 = shape.shape_of %17970 : tensor -> tensor<2xindex> + %17974 = shape.cstr_broadcastable %17972, %17973 : tensor<2xindex>, tensor<2xindex> + %17975 = shape.assuming %17974 -> (tensor) { + %19688 = shape.broadcast %17972, %17973 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %17971, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %17970, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %17976 = shape.shape_of %17975 : tensor -> tensor<2xindex> + %17977 = shape.cstr_broadcastable %17976, %17973 : tensor<2xindex>, tensor<2xindex> + %17978 = shape.assuming %17977 -> (tensor) { + %19688 = shape.broadcast %17976, %17973 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %17975, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %17970, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %17979 = stablehlo.dot %17978, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %17980 = stablehlo.reshape %17935 : (tensor<3x2xf32>) -> tensor<3x2x1xf32> + %dim_6305 = tensor.dim %17950, %c0 : tensor + %17981 = arith.index_cast %dim_6305 : index to i64 + %from_elements_6306 = tensor.from_elements %17981, %c1_i64 : 
tensor<2xi64> + %17982 = stablehlo.dynamic_reshape %17950, %from_elements_6306 : (tensor, tensor<2xi64>) -> tensor + %dim_6307 = tensor.dim %17947, %c0 : tensor + %17983 = arith.index_cast %dim_6307 : index to i64 + %from_elements_6308 = tensor.from_elements %17983, %c1_i64 : tensor<2xi64> + %17984 = stablehlo.dynamic_reshape %17947, %from_elements_6308 : (tensor, tensor<2xi64>) -> tensor + %17985 = stablehlo.concatenate %17982, %17984, dim = 1 : (tensor, tensor) -> tensor + %17986 = "stablehlo.gather"(%17980, %17985) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %17987 = shape.shape_of %17979 : tensor -> tensor<2xindex> + %17988 = shape.shape_of %17986 : tensor -> tensor<2xindex> + %17989 = shape.cstr_broadcastable %17987, %17988 : tensor<2xindex>, tensor<2xindex> + %17990 = shape.assuming %17989 -> (tensor) { + %19688 = shape.broadcast %17987, %17988 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %17979, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %17986, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %17991 = shape.shape_of %17990 : tensor -> tensor<2xindex> + %17992 = stablehlo.dynamic_broadcast_in_dim %17990, %17991, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %17993 = stablehlo.dynamic_broadcast_in_dim %213, %17991, dims = [] : (tensor, tensor<2xindex>) -> tensor + %17994 = stablehlo.multiply %17992, %17993 : tensor + %dim_6309 = tensor.dim %17953, %c0 : tensor + %17995 = arith.index_cast %dim_6309 : index to i64 + %dim_6310 = tensor.dim %17990, %c0 : tensor + %17996 = arith.index_cast %dim_6310 : index to i64 + %17997 = arith.maxsi %17995, %17996 : i64 + %17998 = arith.index_cast %17997 : i64 to index + %from_elements_6311 = tensor.from_elements %17998, %c4096 : tensor<2xindex> + %17999 = stablehlo.dynamic_broadcast_in_dim %17953, %from_elements_6311, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_6312 = tensor.dim %17999, %c0 : tensor + %18000 = arith.index_cast %dim_6312 : index to i64 + %from_elements_6313 = tensor.from_elements %18000, %c4096_i64 : tensor<2xi64> + %18001 = stablehlo.real_dynamic_slice %17994, %c_22, %from_elements_6313, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_6314 = tensor.from_elements %18000, %c4096_i64, %c1_i64 : tensor<3xi64> + %18002 = stablehlo.dynamic_reshape %17999, %from_elements_6314 : (tensor, tensor<3xi64>) -> tensor + %18003 = stablehlo.dynamic_iota %from_elements_6314, dim = 1 : (tensor<3xi64>) -> tensor + %18004 = stablehlo.concatenate %18002, %18003, dim = 2 : (tensor, tensor) -> tensor + %18005 = "stablehlo.scatter"(%cst_2, %18004, %18001) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %18006 = stablehlo.slice %17940 [1:2, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %18007 = stablehlo.reshape %18006 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %18008 = stablehlo.custom_call @byteir.non_zero(%18007) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_6315 = tensor.dim %18008, %c0 : tensor + %18009 = 
arith.index_cast %dim_6315 : index to i64 + %from_elements_6316 = tensor.from_elements %18009, %c1_i64 : tensor<2xi64> + %18010 = stablehlo.real_dynamic_slice %18008, %c_22, %from_elements_6316, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6317 = tensor.dim %18010, %c0 : tensor + %18011 = arith.index_cast %dim_6317 : index to i64 + %from_elements_6318 = tensor.from_elements %18011 : tensor<1xi64> + %18012 = stablehlo.dynamic_reshape %18010, %from_elements_6318 : (tensor, tensor<1xi64>) -> tensor + %from_elements_6319 = tensor.from_elements %18009, %c2_i64 : tensor<2xi64> + %18013 = stablehlo.real_dynamic_slice %18008, %c_24, %from_elements_6319, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6320 = tensor.dim %18013, %c0 : tensor + %18014 = arith.index_cast %dim_6320 : index to i64 + %from_elements_6321 = tensor.from_elements %18014 : tensor<1xi64> + %18015 = stablehlo.dynamic_reshape %18013, %from_elements_6321 : (tensor, tensor<1xi64>) -> tensor + %dim_6322 = tensor.dim %18015, %c0 : tensor + %18016 = arith.index_cast %dim_6322 : index to i64 + %from_elements_6323 = tensor.from_elements %18016, %c1_i64 : tensor<2xi64> + %18017 = stablehlo.dynamic_reshape %18015, %from_elements_6323 : (tensor, tensor<2xi64>) -> tensor + %dim_6324 = tensor.dim %18017, %c0 : tensor + %18018 = arith.index_cast %dim_6324 : index to i64 + %from_elements_6325 = tensor.from_elements %c1_i64, %18018, %c4096_i64 : tensor<3xi64> + %18019 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_6325, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6326 = tensor.dim %18019, %c1 : tensor<1x?x4096xi64> + %18020 = arith.index_cast %dim_6326 : index to i64 + %from_elements_6327 = tensor.from_elements %c1_i64, %18020, %c4096_i64, %c1_i64 : tensor<4xi64> + %18021 = stablehlo.dynamic_reshape %18019, %from_elements_6327 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %18022 = stablehlo.dynamic_broadcast_in_dim %18017, %from_elements_6325, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6328 = tensor.dim %18022, %c1 : tensor<1x?x4096xi64> + %18023 = arith.index_cast %dim_6328 : index to i64 + %from_elements_6329 = tensor.from_elements %c1_i64, %18023, %c4096_i64, %c1_i64 : tensor<4xi64> + %18024 = stablehlo.dynamic_reshape %18022, %from_elements_6329 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %18025 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_6325, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6330 = tensor.dim %18025, %c1 : tensor<1x?x4096xi64> + %18026 = arith.index_cast %dim_6330 : index to i64 + %from_elements_6331 = tensor.from_elements %c1_i64, %18026, %c4096_i64, %c1_i64 : tensor<4xi64> + %18027 = stablehlo.dynamic_reshape %18025, %from_elements_6331 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %18028 = stablehlo.concatenate %18021, %18024, %18027, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %18029 = "stablehlo.gather"(%17951, %18028) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %18030 = shape.shape_of %18029 : tensor<1x?x4096xf32> -> tensor<3xindex> + %18031 = shape.num_elements %18030 : tensor<3xindex> -> index + %18032 = stablehlo.compute_reshape_shape %18031, %c_23 : (index, tensor<2xi64>) -> 
tensor<2xi64> + %18033 = stablehlo.dynamic_reshape %18029, %18032 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %18034 = stablehlo.dot %18033, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %18035 = stablehlo.logistic %18034 : tensor + %18036 = shape.shape_of %18035 : tensor -> tensor<2xindex> + %18037 = shape.shape_of %18034 : tensor -> tensor<2xindex> + %18038 = shape.cstr_broadcastable %18036, %18037 : tensor<2xindex>, tensor<2xindex> + %18039 = shape.assuming %18038 -> (tensor) { + %19688 = shape.broadcast %18036, %18037 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %18035, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %18034, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %18040 = shape.shape_of %18039 : tensor -> tensor<2xindex> + %18041 = shape.cstr_broadcastable %18040, %18037 : tensor<2xindex>, tensor<2xindex> + %18042 = shape.assuming %18041 -> (tensor) { + %19688 = shape.broadcast %18040, %18037 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %18039, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %18034, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %18043 = stablehlo.dot %18042, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_6332 = tensor.dim %18015, %c0 : tensor + %18044 = arith.index_cast %dim_6332 : index to i64 + %from_elements_6333 = tensor.from_elements %18044, %c1_i64 : tensor<2xi64> + %18045 = stablehlo.dynamic_reshape %18015, %from_elements_6333 : (tensor, tensor<2xi64>) -> tensor + %dim_6334 = tensor.dim %18012, %c0 : tensor + %18046 = arith.index_cast %dim_6334 : index to i64 + %from_elements_6335 = tensor.from_elements %18046, %c1_i64 : tensor<2xi64> + %18047 = stablehlo.dynamic_reshape %18012, %from_elements_6335 : (tensor, tensor<2xi64>) -> tensor + %18048 = stablehlo.concatenate %18045, %18047, dim = 1 : (tensor, tensor) -> tensor + %18049 = "stablehlo.gather"(%17980, %18048) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %18050 = shape.shape_of %18043 : tensor -> tensor<2xindex> + %18051 = shape.shape_of %18049 : tensor -> tensor<2xindex> + %18052 = shape.cstr_broadcastable %18050, %18051 : tensor<2xindex>, tensor<2xindex> + %18053 = shape.assuming %18052 -> (tensor) { + %19688 = shape.broadcast %18050, %18051 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %18043, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %18049, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %18054 = shape.shape_of %18053 : tensor -> tensor<2xindex> + %18055 = stablehlo.dynamic_broadcast_in_dim %18053, %18054, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %18056 = stablehlo.dynamic_broadcast_in_dim %213, %18054, dims = [] : (tensor, tensor<2xindex>) -> tensor + %18057 = stablehlo.multiply %18055, %18056 : tensor + %dim_6336 = tensor.dim %18017, %c0 : tensor + %18058 = arith.index_cast %dim_6336 : index to i64 + 
%dim_6337 = tensor.dim %18053, %c0 : tensor + %18059 = arith.index_cast %dim_6337 : index to i64 + %18060 = arith.maxsi %18058, %18059 : i64 + %18061 = arith.index_cast %18060 : i64 to index + %from_elements_6338 = tensor.from_elements %18061, %c4096 : tensor<2xindex> + %18062 = stablehlo.dynamic_broadcast_in_dim %18017, %from_elements_6338, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_6339 = tensor.dim %18062, %c0 : tensor + %18063 = arith.index_cast %dim_6339 : index to i64 + %from_elements_6340 = tensor.from_elements %18063, %c4096_i64 : tensor<2xi64> + %18064 = stablehlo.real_dynamic_slice %18057, %c_22, %from_elements_6340, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_6341 = tensor.from_elements %18063, %c4096_i64, %c1_i64 : tensor<3xi64> + %18065 = stablehlo.dynamic_reshape %18062, %from_elements_6341 : (tensor, tensor<3xi64>) -> tensor + %18066 = stablehlo.dynamic_iota %from_elements_6341, dim = 1 : (tensor<3xi64>) -> tensor + %18067 = stablehlo.concatenate %18065, %18066, dim = 2 : (tensor, tensor) -> tensor + %18068 = "stablehlo.scatter"(%18005, %18067, %18064) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %18069 = stablehlo.slice %17940 [2:3, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %18070 = stablehlo.reshape %18069 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %18071 = stablehlo.custom_call @byteir.non_zero(%18070) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_6342 = tensor.dim %18071, %c0 : tensor + %18072 = arith.index_cast %dim_6342 : index to i64 + %from_elements_6343 = tensor.from_elements %18072, %c1_i64 : tensor<2xi64> + %18073 = stablehlo.real_dynamic_slice %18071, %c_22, %from_elements_6343, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6344 = tensor.dim %18073, %c0 : tensor + %18074 = arith.index_cast %dim_6344 : index to i64 + %from_elements_6345 = tensor.from_elements %18074 : tensor<1xi64> + %18075 = stablehlo.dynamic_reshape %18073, %from_elements_6345 : (tensor, tensor<1xi64>) -> tensor + %from_elements_6346 = tensor.from_elements %18072, %c2_i64 : tensor<2xi64> + %18076 = stablehlo.real_dynamic_slice %18071, %c_24, %from_elements_6346, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6347 = tensor.dim %18076, %c0 : tensor + %18077 = arith.index_cast %dim_6347 : index to i64 + %from_elements_6348 = tensor.from_elements %18077 : tensor<1xi64> + %18078 = stablehlo.dynamic_reshape %18076, %from_elements_6348 : (tensor, tensor<1xi64>) -> tensor + %dim_6349 = tensor.dim %18078, %c0 : tensor + %18079 = arith.index_cast %dim_6349 : index to i64 + %from_elements_6350 = tensor.from_elements %18079, %c1_i64 : tensor<2xi64> + %18080 = stablehlo.dynamic_reshape %18078, %from_elements_6350 : (tensor, tensor<2xi64>) -> tensor + %dim_6351 = tensor.dim %18080, %c0 : tensor + %18081 = arith.index_cast %dim_6351 : index to i64 + %from_elements_6352 = tensor.from_elements %c1_i64, %18081, %c4096_i64 : tensor<3xi64> + %18082 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_6352, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6353 = tensor.dim %18082, %c1 : tensor<1x?x4096xi64> + %18083 = arith.index_cast %dim_6353 : index to i64 + 
%from_elements_6354 = tensor.from_elements %c1_i64, %18083, %c4096_i64, %c1_i64 : tensor<4xi64> + %18084 = stablehlo.dynamic_reshape %18082, %from_elements_6354 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %18085 = stablehlo.dynamic_broadcast_in_dim %18080, %from_elements_6352, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6355 = tensor.dim %18085, %c1 : tensor<1x?x4096xi64> + %18086 = arith.index_cast %dim_6355 : index to i64 + %from_elements_6356 = tensor.from_elements %c1_i64, %18086, %c4096_i64, %c1_i64 : tensor<4xi64> + %18087 = stablehlo.dynamic_reshape %18085, %from_elements_6356 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %18088 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_6352, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6357 = tensor.dim %18088, %c1 : tensor<1x?x4096xi64> + %18089 = arith.index_cast %dim_6357 : index to i64 + %from_elements_6358 = tensor.from_elements %c1_i64, %18089, %c4096_i64, %c1_i64 : tensor<4xi64> + %18090 = stablehlo.dynamic_reshape %18088, %from_elements_6358 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %18091 = stablehlo.concatenate %18084, %18087, %18090, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %18092 = "stablehlo.gather"(%17951, %18091) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %18093 = shape.shape_of %18092 : tensor<1x?x4096xf32> -> tensor<3xindex> + %18094 = shape.num_elements %18093 : tensor<3xindex> -> index + %18095 = stablehlo.compute_reshape_shape %18094, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %18096 = stablehlo.dynamic_reshape %18092, %18095 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %18097 = stablehlo.dot %18096, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %18098 = stablehlo.logistic %18097 : tensor + %18099 = shape.shape_of %18098 : tensor -> tensor<2xindex> + %18100 = shape.shape_of %18097 : tensor -> tensor<2xindex> + %18101 = shape.cstr_broadcastable %18099, %18100 : tensor<2xindex>, tensor<2xindex> + %18102 = shape.assuming %18101 -> (tensor) { + %19688 = shape.broadcast %18099, %18100 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %18098, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %18097, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %18103 = shape.shape_of %18102 : tensor -> tensor<2xindex> + %18104 = shape.cstr_broadcastable %18103, %18100 : tensor<2xindex>, tensor<2xindex> + %18105 = shape.assuming %18104 -> (tensor) { + %19688 = shape.broadcast %18103, %18100 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %18102, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %18097, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %18106 = stablehlo.dot %18105, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_6359 = tensor.dim %18078, %c0 : tensor + %18107 = arith.index_cast %dim_6359 : index to i64 + %from_elements_6360 = 
tensor.from_elements %18107, %c1_i64 : tensor<2xi64> + %18108 = stablehlo.dynamic_reshape %18078, %from_elements_6360 : (tensor, tensor<2xi64>) -> tensor + %dim_6361 = tensor.dim %18075, %c0 : tensor + %18109 = arith.index_cast %dim_6361 : index to i64 + %from_elements_6362 = tensor.from_elements %18109, %c1_i64 : tensor<2xi64> + %18110 = stablehlo.dynamic_reshape %18075, %from_elements_6362 : (tensor, tensor<2xi64>) -> tensor + %18111 = stablehlo.concatenate %18108, %18110, dim = 1 : (tensor, tensor) -> tensor + %18112 = "stablehlo.gather"(%17980, %18111) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %18113 = shape.shape_of %18106 : tensor -> tensor<2xindex> + %18114 = shape.shape_of %18112 : tensor -> tensor<2xindex> + %18115 = shape.cstr_broadcastable %18113, %18114 : tensor<2xindex>, tensor<2xindex> + %18116 = shape.assuming %18115 -> (tensor) { + %19688 = shape.broadcast %18113, %18114 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %18106, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %18112, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %18117 = shape.shape_of %18116 : tensor -> tensor<2xindex> + %18118 = stablehlo.dynamic_broadcast_in_dim %18116, %18117, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %18119 = stablehlo.dynamic_broadcast_in_dim %213, %18117, dims = [] : (tensor, tensor<2xindex>) -> tensor + %18120 = stablehlo.multiply %18118, %18119 : tensor + %dim_6363 = tensor.dim %18080, %c0 : tensor + %18121 = arith.index_cast %dim_6363 : index to i64 + %dim_6364 = tensor.dim %18116, %c0 : tensor + %18122 = arith.index_cast %dim_6364 : index to i64 + %18123 = arith.maxsi %18121, %18122 : i64 + %18124 = arith.index_cast %18123 : i64 to index + %from_elements_6365 = tensor.from_elements %18124, %c4096 : tensor<2xindex> + %18125 = stablehlo.dynamic_broadcast_in_dim %18080, %from_elements_6365, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_6366 = tensor.dim %18125, %c0 : tensor + %18126 = arith.index_cast %dim_6366 : index to i64 + %from_elements_6367 = tensor.from_elements %18126, %c4096_i64 : tensor<2xi64> + %18127 = stablehlo.real_dynamic_slice %18120, %c_22, %from_elements_6367, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_6368 = tensor.from_elements %18126, %c4096_i64, %c1_i64 : tensor<3xi64> + %18128 = stablehlo.dynamic_reshape %18125, %from_elements_6368 : (tensor, tensor<3xi64>) -> tensor + %18129 = stablehlo.dynamic_iota %from_elements_6368, dim = 1 : (tensor<3xi64>) -> tensor + %18130 = stablehlo.concatenate %18128, %18129, dim = 2 : (tensor, tensor) -> tensor + %18131 = "stablehlo.scatter"(%18068, %18130, %18127) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %18132 = stablehlo.slice %17940 [3:4, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %18133 = stablehlo.reshape %18132 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %18134 = stablehlo.custom_call @byteir.non_zero(%18133) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_6369 = 
tensor.dim %18134, %c0 : tensor + %18135 = arith.index_cast %dim_6369 : index to i64 + %from_elements_6370 = tensor.from_elements %18135, %c1_i64 : tensor<2xi64> + %18136 = stablehlo.real_dynamic_slice %18134, %c_22, %from_elements_6370, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6371 = tensor.dim %18136, %c0 : tensor + %18137 = arith.index_cast %dim_6371 : index to i64 + %from_elements_6372 = tensor.from_elements %18137 : tensor<1xi64> + %18138 = stablehlo.dynamic_reshape %18136, %from_elements_6372 : (tensor, tensor<1xi64>) -> tensor + %from_elements_6373 = tensor.from_elements %18135, %c2_i64 : tensor<2xi64> + %18139 = stablehlo.real_dynamic_slice %18134, %c_24, %from_elements_6373, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6374 = tensor.dim %18139, %c0 : tensor + %18140 = arith.index_cast %dim_6374 : index to i64 + %from_elements_6375 = tensor.from_elements %18140 : tensor<1xi64> + %18141 = stablehlo.dynamic_reshape %18139, %from_elements_6375 : (tensor, tensor<1xi64>) -> tensor + %dim_6376 = tensor.dim %18141, %c0 : tensor + %18142 = arith.index_cast %dim_6376 : index to i64 + %from_elements_6377 = tensor.from_elements %18142, %c1_i64 : tensor<2xi64> + %18143 = stablehlo.dynamic_reshape %18141, %from_elements_6377 : (tensor, tensor<2xi64>) -> tensor + %dim_6378 = tensor.dim %18143, %c0 : tensor + %18144 = arith.index_cast %dim_6378 : index to i64 + %from_elements_6379 = tensor.from_elements %c1_i64, %18144, %c4096_i64 : tensor<3xi64> + %18145 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_6379, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6380 = tensor.dim %18145, %c1 : tensor<1x?x4096xi64> + %18146 = arith.index_cast %dim_6380 : index to i64 + %from_elements_6381 = tensor.from_elements %c1_i64, %18146, %c4096_i64, %c1_i64 : tensor<4xi64> + %18147 = stablehlo.dynamic_reshape %18145, %from_elements_6381 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %18148 = stablehlo.dynamic_broadcast_in_dim %18143, %from_elements_6379, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6382 = tensor.dim %18148, %c1 : tensor<1x?x4096xi64> + %18149 = arith.index_cast %dim_6382 : index to i64 + %from_elements_6383 = tensor.from_elements %c1_i64, %18149, %c4096_i64, %c1_i64 : tensor<4xi64> + %18150 = stablehlo.dynamic_reshape %18148, %from_elements_6383 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %18151 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_6379, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6384 = tensor.dim %18151, %c1 : tensor<1x?x4096xi64> + %18152 = arith.index_cast %dim_6384 : index to i64 + %from_elements_6385 = tensor.from_elements %c1_i64, %18152, %c4096_i64, %c1_i64 : tensor<4xi64> + %18153 = stablehlo.dynamic_reshape %18151, %from_elements_6385 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %18154 = stablehlo.concatenate %18147, %18150, %18153, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %18155 = "stablehlo.gather"(%17951, %18154) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %18156 = shape.shape_of %18155 : tensor<1x?x4096xf32> -> tensor<3xindex> + %18157 = shape.num_elements %18156 : tensor<3xindex> -> index + %18158 = stablehlo.compute_reshape_shape 
%18157, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %18159 = stablehlo.dynamic_reshape %18155, %18158 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %18160 = stablehlo.dot %18159, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %18161 = stablehlo.logistic %18160 : tensor + %18162 = shape.shape_of %18161 : tensor -> tensor<2xindex> + %18163 = shape.shape_of %18160 : tensor -> tensor<2xindex> + %18164 = shape.cstr_broadcastable %18162, %18163 : tensor<2xindex>, tensor<2xindex> + %18165 = shape.assuming %18164 -> (tensor) { + %19688 = shape.broadcast %18162, %18163 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %18161, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %18160, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %18166 = shape.shape_of %18165 : tensor -> tensor<2xindex> + %18167 = shape.cstr_broadcastable %18166, %18163 : tensor<2xindex>, tensor<2xindex> + %18168 = shape.assuming %18167 -> (tensor) { + %19688 = shape.broadcast %18166, %18163 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %18165, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %18160, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %18169 = stablehlo.dot %18168, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_6386 = tensor.dim %18141, %c0 : tensor + %18170 = arith.index_cast %dim_6386 : index to i64 + %from_elements_6387 = tensor.from_elements %18170, %c1_i64 : tensor<2xi64> + %18171 = stablehlo.dynamic_reshape %18141, %from_elements_6387 : (tensor, tensor<2xi64>) -> tensor + %dim_6388 = tensor.dim %18138, %c0 : tensor + %18172 = arith.index_cast %dim_6388 : index to i64 + %from_elements_6389 = tensor.from_elements %18172, %c1_i64 : tensor<2xi64> + %18173 = stablehlo.dynamic_reshape %18138, %from_elements_6389 : (tensor, tensor<2xi64>) -> tensor + %18174 = stablehlo.concatenate %18171, %18173, dim = 1 : (tensor, tensor) -> tensor + %18175 = "stablehlo.gather"(%17980, %18174) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %18176 = shape.shape_of %18169 : tensor -> tensor<2xindex> + %18177 = shape.shape_of %18175 : tensor -> tensor<2xindex> + %18178 = shape.cstr_broadcastable %18176, %18177 : tensor<2xindex>, tensor<2xindex> + %18179 = shape.assuming %18178 -> (tensor) { + %19688 = shape.broadcast %18176, %18177 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %18169, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %18175, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %18180 = shape.shape_of %18179 : tensor -> tensor<2xindex> + %18181 = stablehlo.dynamic_broadcast_in_dim %18179, %18180, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %18182 = stablehlo.dynamic_broadcast_in_dim %213, %18180, dims = [] : (tensor, tensor<2xindex>) -> tensor + %18183 = stablehlo.multiply %18181, %18182 : tensor + %dim_6390 = tensor.dim %18143, %c0 : tensor + %18184 = 
arith.index_cast %dim_6390 : index to i64 + %dim_6391 = tensor.dim %18179, %c0 : tensor + %18185 = arith.index_cast %dim_6391 : index to i64 + %18186 = arith.maxsi %18184, %18185 : i64 + %18187 = arith.index_cast %18186 : i64 to index + %from_elements_6392 = tensor.from_elements %18187, %c4096 : tensor<2xindex> + %18188 = stablehlo.dynamic_broadcast_in_dim %18143, %from_elements_6392, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_6393 = tensor.dim %18188, %c0 : tensor + %18189 = arith.index_cast %dim_6393 : index to i64 + %from_elements_6394 = tensor.from_elements %18189, %c4096_i64 : tensor<2xi64> + %18190 = stablehlo.real_dynamic_slice %18183, %c_22, %from_elements_6394, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_6395 = tensor.from_elements %18189, %c4096_i64, %c1_i64 : tensor<3xi64> + %18191 = stablehlo.dynamic_reshape %18188, %from_elements_6395 : (tensor, tensor<3xi64>) -> tensor + %18192 = stablehlo.dynamic_iota %from_elements_6395, dim = 1 : (tensor<3xi64>) -> tensor + %18193 = stablehlo.concatenate %18191, %18192, dim = 2 : (tensor, tensor) -> tensor + %18194 = "stablehlo.scatter"(%18131, %18193, %18190) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %18195 = stablehlo.slice %17940 [4:5, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %18196 = stablehlo.reshape %18195 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %18197 = stablehlo.custom_call @byteir.non_zero(%18196) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_6396 = tensor.dim %18197, %c0 : tensor + %18198 = arith.index_cast %dim_6396 : index to i64 + %from_elements_6397 = tensor.from_elements %18198, %c1_i64 : tensor<2xi64> + %18199 = stablehlo.real_dynamic_slice %18197, %c_22, %from_elements_6397, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6398 = tensor.dim %18199, %c0 : tensor + %18200 = arith.index_cast %dim_6398 : index to i64 + %from_elements_6399 = tensor.from_elements %18200 : tensor<1xi64> + %18201 = stablehlo.dynamic_reshape %18199, %from_elements_6399 : (tensor, tensor<1xi64>) -> tensor + %from_elements_6400 = tensor.from_elements %18198, %c2_i64 : tensor<2xi64> + %18202 = stablehlo.real_dynamic_slice %18197, %c_24, %from_elements_6400, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6401 = tensor.dim %18202, %c0 : tensor + %18203 = arith.index_cast %dim_6401 : index to i64 + %from_elements_6402 = tensor.from_elements %18203 : tensor<1xi64> + %18204 = stablehlo.dynamic_reshape %18202, %from_elements_6402 : (tensor, tensor<1xi64>) -> tensor + %dim_6403 = tensor.dim %18204, %c0 : tensor + %18205 = arith.index_cast %dim_6403 : index to i64 + %from_elements_6404 = tensor.from_elements %18205, %c1_i64 : tensor<2xi64> + %18206 = stablehlo.dynamic_reshape %18204, %from_elements_6404 : (tensor, tensor<2xi64>) -> tensor + %dim_6405 = tensor.dim %18206, %c0 : tensor + %18207 = arith.index_cast %dim_6405 : index to i64 + %from_elements_6406 = tensor.from_elements %c1_i64, %18207, %c4096_i64 : tensor<3xi64> + %18208 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_6406, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6407 = tensor.dim %18208, %c1 : tensor<1x?x4096xi64> + %18209 = 
arith.index_cast %dim_6407 : index to i64 + %from_elements_6408 = tensor.from_elements %c1_i64, %18209, %c4096_i64, %c1_i64 : tensor<4xi64> + %18210 = stablehlo.dynamic_reshape %18208, %from_elements_6408 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %18211 = stablehlo.dynamic_broadcast_in_dim %18206, %from_elements_6406, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6409 = tensor.dim %18211, %c1 : tensor<1x?x4096xi64> + %18212 = arith.index_cast %dim_6409 : index to i64 + %from_elements_6410 = tensor.from_elements %c1_i64, %18212, %c4096_i64, %c1_i64 : tensor<4xi64> + %18213 = stablehlo.dynamic_reshape %18211, %from_elements_6410 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %18214 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_6406, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6411 = tensor.dim %18214, %c1 : tensor<1x?x4096xi64> + %18215 = arith.index_cast %dim_6411 : index to i64 + %from_elements_6412 = tensor.from_elements %c1_i64, %18215, %c4096_i64, %c1_i64 : tensor<4xi64> + %18216 = stablehlo.dynamic_reshape %18214, %from_elements_6412 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %18217 = stablehlo.concatenate %18210, %18213, %18216, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %18218 = "stablehlo.gather"(%17951, %18217) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %18219 = shape.shape_of %18218 : tensor<1x?x4096xf32> -> tensor<3xindex> + %18220 = shape.num_elements %18219 : tensor<3xindex> -> index + %18221 = stablehlo.compute_reshape_shape %18220, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %18222 = stablehlo.dynamic_reshape %18218, %18221 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %18223 = stablehlo.dot %18222, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %18224 = stablehlo.logistic %18223 : tensor + %18225 = shape.shape_of %18224 : tensor -> tensor<2xindex> + %18226 = shape.shape_of %18223 : tensor -> tensor<2xindex> + %18227 = shape.cstr_broadcastable %18225, %18226 : tensor<2xindex>, tensor<2xindex> + %18228 = shape.assuming %18227 -> (tensor) { + %19688 = shape.broadcast %18225, %18226 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %18224, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %18223, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %18229 = shape.shape_of %18228 : tensor -> tensor<2xindex> + %18230 = shape.cstr_broadcastable %18229, %18226 : tensor<2xindex>, tensor<2xindex> + %18231 = shape.assuming %18230 -> (tensor) { + %19688 = shape.broadcast %18229, %18226 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %18228, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %18223, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %18232 = stablehlo.dot %18231, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_6413 = tensor.dim %18204, %c0 : tensor + %18233 = arith.index_cast %dim_6413 : index to 
i64 + %from_elements_6414 = tensor.from_elements %18233, %c1_i64 : tensor<2xi64> + %18234 = stablehlo.dynamic_reshape %18204, %from_elements_6414 : (tensor, tensor<2xi64>) -> tensor + %dim_6415 = tensor.dim %18201, %c0 : tensor + %18235 = arith.index_cast %dim_6415 : index to i64 + %from_elements_6416 = tensor.from_elements %18235, %c1_i64 : tensor<2xi64> + %18236 = stablehlo.dynamic_reshape %18201, %from_elements_6416 : (tensor, tensor<2xi64>) -> tensor + %18237 = stablehlo.concatenate %18234, %18236, dim = 1 : (tensor, tensor) -> tensor + %18238 = "stablehlo.gather"(%17980, %18237) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %18239 = shape.shape_of %18232 : tensor -> tensor<2xindex> + %18240 = shape.shape_of %18238 : tensor -> tensor<2xindex> + %18241 = shape.cstr_broadcastable %18239, %18240 : tensor<2xindex>, tensor<2xindex> + %18242 = shape.assuming %18241 -> (tensor) { + %19688 = shape.broadcast %18239, %18240 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %18232, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %18238, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %18243 = shape.shape_of %18242 : tensor -> tensor<2xindex> + %18244 = stablehlo.dynamic_broadcast_in_dim %18242, %18243, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %18245 = stablehlo.dynamic_broadcast_in_dim %213, %18243, dims = [] : (tensor, tensor<2xindex>) -> tensor + %18246 = stablehlo.multiply %18244, %18245 : tensor + %dim_6417 = tensor.dim %18206, %c0 : tensor + %18247 = arith.index_cast %dim_6417 : index to i64 + %dim_6418 = tensor.dim %18242, %c0 : tensor + %18248 = arith.index_cast %dim_6418 : index to i64 + %18249 = arith.maxsi %18247, %18248 : i64 + %18250 = arith.index_cast %18249 : i64 to index + %from_elements_6419 = tensor.from_elements %18250, %c4096 : tensor<2xindex> + %18251 = stablehlo.dynamic_broadcast_in_dim %18206, %from_elements_6419, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_6420 = tensor.dim %18251, %c0 : tensor + %18252 = arith.index_cast %dim_6420 : index to i64 + %from_elements_6421 = tensor.from_elements %18252, %c4096_i64 : tensor<2xi64> + %18253 = stablehlo.real_dynamic_slice %18246, %c_22, %from_elements_6421, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_6422 = tensor.from_elements %18252, %c4096_i64, %c1_i64 : tensor<3xi64> + %18254 = stablehlo.dynamic_reshape %18251, %from_elements_6422 : (tensor, tensor<3xi64>) -> tensor + %18255 = stablehlo.dynamic_iota %from_elements_6422, dim = 1 : (tensor<3xi64>) -> tensor + %18256 = stablehlo.concatenate %18254, %18255, dim = 2 : (tensor, tensor) -> tensor + %18257 = "stablehlo.scatter"(%18194, %18256, %18253) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %18258 = stablehlo.slice %17940 [5:6, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %18259 = stablehlo.reshape %18258 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %18260 = stablehlo.custom_call @byteir.non_zero(%18259) {byteir_attrs = {}} : (tensor<2x3xi64>) -> 
tensor + %dim_6423 = tensor.dim %18260, %c0 : tensor + %18261 = arith.index_cast %dim_6423 : index to i64 + %from_elements_6424 = tensor.from_elements %18261, %c1_i64 : tensor<2xi64> + %18262 = stablehlo.real_dynamic_slice %18260, %c_22, %from_elements_6424, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6425 = tensor.dim %18262, %c0 : tensor + %18263 = arith.index_cast %dim_6425 : index to i64 + %from_elements_6426 = tensor.from_elements %18263 : tensor<1xi64> + %18264 = stablehlo.dynamic_reshape %18262, %from_elements_6426 : (tensor, tensor<1xi64>) -> tensor + %from_elements_6427 = tensor.from_elements %18261, %c2_i64 : tensor<2xi64> + %18265 = stablehlo.real_dynamic_slice %18260, %c_24, %from_elements_6427, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6428 = tensor.dim %18265, %c0 : tensor + %18266 = arith.index_cast %dim_6428 : index to i64 + %from_elements_6429 = tensor.from_elements %18266 : tensor<1xi64> + %18267 = stablehlo.dynamic_reshape %18265, %from_elements_6429 : (tensor, tensor<1xi64>) -> tensor + %dim_6430 = tensor.dim %18267, %c0 : tensor + %18268 = arith.index_cast %dim_6430 : index to i64 + %from_elements_6431 = tensor.from_elements %18268, %c1_i64 : tensor<2xi64> + %18269 = stablehlo.dynamic_reshape %18267, %from_elements_6431 : (tensor, tensor<2xi64>) -> tensor + %dim_6432 = tensor.dim %18269, %c0 : tensor + %18270 = arith.index_cast %dim_6432 : index to i64 + %from_elements_6433 = tensor.from_elements %c1_i64, %18270, %c4096_i64 : tensor<3xi64> + %18271 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_6433, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6434 = tensor.dim %18271, %c1 : tensor<1x?x4096xi64> + %18272 = arith.index_cast %dim_6434 : index to i64 + %from_elements_6435 = tensor.from_elements %c1_i64, %18272, %c4096_i64, %c1_i64 : tensor<4xi64> + %18273 = stablehlo.dynamic_reshape %18271, %from_elements_6435 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %18274 = stablehlo.dynamic_broadcast_in_dim %18269, %from_elements_6433, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6436 = tensor.dim %18274, %c1 : tensor<1x?x4096xi64> + %18275 = arith.index_cast %dim_6436 : index to i64 + %from_elements_6437 = tensor.from_elements %c1_i64, %18275, %c4096_i64, %c1_i64 : tensor<4xi64> + %18276 = stablehlo.dynamic_reshape %18274, %from_elements_6437 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %18277 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_6433, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6438 = tensor.dim %18277, %c1 : tensor<1x?x4096xi64> + %18278 = arith.index_cast %dim_6438 : index to i64 + %from_elements_6439 = tensor.from_elements %c1_i64, %18278, %c4096_i64, %c1_i64 : tensor<4xi64> + %18279 = stablehlo.dynamic_reshape %18277, %from_elements_6439 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %18280 = stablehlo.concatenate %18273, %18276, %18279, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %18281 = "stablehlo.gather"(%17951, %18280) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %18282 = shape.shape_of %18281 : tensor<1x?x4096xf32> -> tensor<3xindex> + %18283 = shape.num_elements %18282 : tensor<3xindex> -> index + %18284 = 
stablehlo.compute_reshape_shape %18283, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %18285 = stablehlo.dynamic_reshape %18281, %18284 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %18286 = stablehlo.dot %18285, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %18287 = stablehlo.logistic %18286 : tensor + %18288 = shape.shape_of %18287 : tensor -> tensor<2xindex> + %18289 = shape.shape_of %18286 : tensor -> tensor<2xindex> + %18290 = shape.cstr_broadcastable %18288, %18289 : tensor<2xindex>, tensor<2xindex> + %18291 = shape.assuming %18290 -> (tensor) { + %19688 = shape.broadcast %18288, %18289 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %18287, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %18286, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %18292 = shape.shape_of %18291 : tensor -> tensor<2xindex> + %18293 = shape.cstr_broadcastable %18292, %18289 : tensor<2xindex>, tensor<2xindex> + %18294 = shape.assuming %18293 -> (tensor) { + %19688 = shape.broadcast %18292, %18289 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %18291, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %18286, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %18295 = stablehlo.dot %18294, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_6440 = tensor.dim %18267, %c0 : tensor + %18296 = arith.index_cast %dim_6440 : index to i64 + %from_elements_6441 = tensor.from_elements %18296, %c1_i64 : tensor<2xi64> + %18297 = stablehlo.dynamic_reshape %18267, %from_elements_6441 : (tensor, tensor<2xi64>) -> tensor + %dim_6442 = tensor.dim %18264, %c0 : tensor + %18298 = arith.index_cast %dim_6442 : index to i64 + %from_elements_6443 = tensor.from_elements %18298, %c1_i64 : tensor<2xi64> + %18299 = stablehlo.dynamic_reshape %18264, %from_elements_6443 : (tensor, tensor<2xi64>) -> tensor + %18300 = stablehlo.concatenate %18297, %18299, dim = 1 : (tensor, tensor) -> tensor + %18301 = "stablehlo.gather"(%17980, %18300) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %18302 = shape.shape_of %18295 : tensor -> tensor<2xindex> + %18303 = shape.shape_of %18301 : tensor -> tensor<2xindex> + %18304 = shape.cstr_broadcastable %18302, %18303 : tensor<2xindex>, tensor<2xindex> + %18305 = shape.assuming %18304 -> (tensor) { + %19688 = shape.broadcast %18302, %18303 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %18295, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %18301, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %18306 = shape.shape_of %18305 : tensor -> tensor<2xindex> + %18307 = stablehlo.dynamic_broadcast_in_dim %18305, %18306, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %18308 = stablehlo.dynamic_broadcast_in_dim %213, %18306, dims = [] : (tensor, tensor<2xindex>) -> tensor + %18309 = stablehlo.multiply %18307, %18308 : tensor + %dim_6444 = tensor.dim 
%18269, %c0 : tensor + %18310 = arith.index_cast %dim_6444 : index to i64 + %dim_6445 = tensor.dim %18305, %c0 : tensor + %18311 = arith.index_cast %dim_6445 : index to i64 + %18312 = arith.maxsi %18310, %18311 : i64 + %18313 = arith.index_cast %18312 : i64 to index + %from_elements_6446 = tensor.from_elements %18313, %c4096 : tensor<2xindex> + %18314 = stablehlo.dynamic_broadcast_in_dim %18269, %from_elements_6446, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_6447 = tensor.dim %18314, %c0 : tensor + %18315 = arith.index_cast %dim_6447 : index to i64 + %from_elements_6448 = tensor.from_elements %18315, %c4096_i64 : tensor<2xi64> + %18316 = stablehlo.real_dynamic_slice %18309, %c_22, %from_elements_6448, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_6449 = tensor.from_elements %18315, %c4096_i64, %c1_i64 : tensor<3xi64> + %18317 = stablehlo.dynamic_reshape %18314, %from_elements_6449 : (tensor, tensor<3xi64>) -> tensor + %18318 = stablehlo.dynamic_iota %from_elements_6449, dim = 1 : (tensor<3xi64>) -> tensor + %18319 = stablehlo.concatenate %18317, %18318, dim = 2 : (tensor, tensor) -> tensor + %18320 = "stablehlo.scatter"(%18257, %18319, %18316) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %18321 = stablehlo.slice %17940 [6:7, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %18322 = stablehlo.reshape %18321 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %18323 = stablehlo.custom_call @byteir.non_zero(%18322) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_6450 = tensor.dim %18323, %c0 : tensor + %18324 = arith.index_cast %dim_6450 : index to i64 + %from_elements_6451 = tensor.from_elements %18324, %c1_i64 : tensor<2xi64> + %18325 = stablehlo.real_dynamic_slice %18323, %c_22, %from_elements_6451, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6452 = tensor.dim %18325, %c0 : tensor + %18326 = arith.index_cast %dim_6452 : index to i64 + %from_elements_6453 = tensor.from_elements %18326 : tensor<1xi64> + %18327 = stablehlo.dynamic_reshape %18325, %from_elements_6453 : (tensor, tensor<1xi64>) -> tensor + %from_elements_6454 = tensor.from_elements %18324, %c2_i64 : tensor<2xi64> + %18328 = stablehlo.real_dynamic_slice %18323, %c_24, %from_elements_6454, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6455 = tensor.dim %18328, %c0 : tensor + %18329 = arith.index_cast %dim_6455 : index to i64 + %from_elements_6456 = tensor.from_elements %18329 : tensor<1xi64> + %18330 = stablehlo.dynamic_reshape %18328, %from_elements_6456 : (tensor, tensor<1xi64>) -> tensor + %dim_6457 = tensor.dim %18330, %c0 : tensor + %18331 = arith.index_cast %dim_6457 : index to i64 + %from_elements_6458 = tensor.from_elements %18331, %c1_i64 : tensor<2xi64> + %18332 = stablehlo.dynamic_reshape %18330, %from_elements_6458 : (tensor, tensor<2xi64>) -> tensor + %dim_6459 = tensor.dim %18332, %c0 : tensor + %18333 = arith.index_cast %dim_6459 : index to i64 + %from_elements_6460 = tensor.from_elements %c1_i64, %18333, %c4096_i64 : tensor<3xi64> + %18334 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_6460, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6461 = tensor.dim %18334, %c1 : 
tensor<1x?x4096xi64> + %18335 = arith.index_cast %dim_6461 : index to i64 + %from_elements_6462 = tensor.from_elements %c1_i64, %18335, %c4096_i64, %c1_i64 : tensor<4xi64> + %18336 = stablehlo.dynamic_reshape %18334, %from_elements_6462 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %18337 = stablehlo.dynamic_broadcast_in_dim %18332, %from_elements_6460, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6463 = tensor.dim %18337, %c1 : tensor<1x?x4096xi64> + %18338 = arith.index_cast %dim_6463 : index to i64 + %from_elements_6464 = tensor.from_elements %c1_i64, %18338, %c4096_i64, %c1_i64 : tensor<4xi64> + %18339 = stablehlo.dynamic_reshape %18337, %from_elements_6464 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %18340 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_6460, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6465 = tensor.dim %18340, %c1 : tensor<1x?x4096xi64> + %18341 = arith.index_cast %dim_6465 : index to i64 + %from_elements_6466 = tensor.from_elements %c1_i64, %18341, %c4096_i64, %c1_i64 : tensor<4xi64> + %18342 = stablehlo.dynamic_reshape %18340, %from_elements_6466 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %18343 = stablehlo.concatenate %18336, %18339, %18342, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %18344 = "stablehlo.gather"(%17951, %18343) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %18345 = shape.shape_of %18344 : tensor<1x?x4096xf32> -> tensor<3xindex> + %18346 = shape.num_elements %18345 : tensor<3xindex> -> index + %18347 = stablehlo.compute_reshape_shape %18346, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %18348 = stablehlo.dynamic_reshape %18344, %18347 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %18349 = stablehlo.dot %18348, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %18350 = stablehlo.logistic %18349 : tensor + %18351 = shape.shape_of %18350 : tensor -> tensor<2xindex> + %18352 = shape.shape_of %18349 : tensor -> tensor<2xindex> + %18353 = shape.cstr_broadcastable %18351, %18352 : tensor<2xindex>, tensor<2xindex> + %18354 = shape.assuming %18353 -> (tensor) { + %19688 = shape.broadcast %18351, %18352 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %18350, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %18349, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %18355 = shape.shape_of %18354 : tensor -> tensor<2xindex> + %18356 = shape.cstr_broadcastable %18355, %18352 : tensor<2xindex>, tensor<2xindex> + %18357 = shape.assuming %18356 -> (tensor) { + %19688 = shape.broadcast %18355, %18352 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %18354, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %18349, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %18358 = stablehlo.dot %18357, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_6467 = tensor.dim %18330, %c0 : tensor + %18359 = 
arith.index_cast %dim_6467 : index to i64 + %from_elements_6468 = tensor.from_elements %18359, %c1_i64 : tensor<2xi64> + %18360 = stablehlo.dynamic_reshape %18330, %from_elements_6468 : (tensor, tensor<2xi64>) -> tensor + %dim_6469 = tensor.dim %18327, %c0 : tensor + %18361 = arith.index_cast %dim_6469 : index to i64 + %from_elements_6470 = tensor.from_elements %18361, %c1_i64 : tensor<2xi64> + %18362 = stablehlo.dynamic_reshape %18327, %from_elements_6470 : (tensor, tensor<2xi64>) -> tensor + %18363 = stablehlo.concatenate %18360, %18362, dim = 1 : (tensor, tensor) -> tensor + %18364 = "stablehlo.gather"(%17980, %18363) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %18365 = shape.shape_of %18358 : tensor -> tensor<2xindex> + %18366 = shape.shape_of %18364 : tensor -> tensor<2xindex> + %18367 = shape.cstr_broadcastable %18365, %18366 : tensor<2xindex>, tensor<2xindex> + %18368 = shape.assuming %18367 -> (tensor) { + %19688 = shape.broadcast %18365, %18366 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %18358, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %18364, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %18369 = shape.shape_of %18368 : tensor -> tensor<2xindex> + %18370 = stablehlo.dynamic_broadcast_in_dim %18368, %18369, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %18371 = stablehlo.dynamic_broadcast_in_dim %213, %18369, dims = [] : (tensor, tensor<2xindex>) -> tensor + %18372 = stablehlo.multiply %18370, %18371 : tensor + %dim_6471 = tensor.dim %18332, %c0 : tensor + %18373 = arith.index_cast %dim_6471 : index to i64 + %dim_6472 = tensor.dim %18368, %c0 : tensor + %18374 = arith.index_cast %dim_6472 : index to i64 + %18375 = arith.maxsi %18373, %18374 : i64 + %18376 = arith.index_cast %18375 : i64 to index + %from_elements_6473 = tensor.from_elements %18376, %c4096 : tensor<2xindex> + %18377 = stablehlo.dynamic_broadcast_in_dim %18332, %from_elements_6473, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_6474 = tensor.dim %18377, %c0 : tensor + %18378 = arith.index_cast %dim_6474 : index to i64 + %from_elements_6475 = tensor.from_elements %18378, %c4096_i64 : tensor<2xi64> + %18379 = stablehlo.real_dynamic_slice %18372, %c_22, %from_elements_6475, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_6476 = tensor.from_elements %18378, %c4096_i64, %c1_i64 : tensor<3xi64> + %18380 = stablehlo.dynamic_reshape %18377, %from_elements_6476 : (tensor, tensor<3xi64>) -> tensor + %18381 = stablehlo.dynamic_iota %from_elements_6476, dim = 1 : (tensor<3xi64>) -> tensor + %18382 = stablehlo.concatenate %18380, %18381, dim = 2 : (tensor, tensor) -> tensor + %18383 = "stablehlo.scatter"(%18320, %18382, %18379) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %18384 = stablehlo.slice %17940 [7:8, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %18385 = stablehlo.reshape %18384 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %18386 = stablehlo.custom_call @byteir.non_zero(%18385) 
{byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_6477 = tensor.dim %18386, %c0 : tensor + %18387 = arith.index_cast %dim_6477 : index to i64 + %from_elements_6478 = tensor.from_elements %18387, %c1_i64 : tensor<2xi64> + %18388 = stablehlo.real_dynamic_slice %18386, %c_22, %from_elements_6478, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6479 = tensor.dim %18388, %c0 : tensor + %18389 = arith.index_cast %dim_6479 : index to i64 + %from_elements_6480 = tensor.from_elements %18389 : tensor<1xi64> + %18390 = stablehlo.dynamic_reshape %18388, %from_elements_6480 : (tensor, tensor<1xi64>) -> tensor + %from_elements_6481 = tensor.from_elements %18387, %c2_i64 : tensor<2xi64> + %18391 = stablehlo.real_dynamic_slice %18386, %c_24, %from_elements_6481, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6482 = tensor.dim %18391, %c0 : tensor + %18392 = arith.index_cast %dim_6482 : index to i64 + %from_elements_6483 = tensor.from_elements %18392 : tensor<1xi64> + %18393 = stablehlo.dynamic_reshape %18391, %from_elements_6483 : (tensor, tensor<1xi64>) -> tensor + %dim_6484 = tensor.dim %18393, %c0 : tensor + %18394 = arith.index_cast %dim_6484 : index to i64 + %from_elements_6485 = tensor.from_elements %18394, %c1_i64 : tensor<2xi64> + %18395 = stablehlo.dynamic_reshape %18393, %from_elements_6485 : (tensor, tensor<2xi64>) -> tensor + %dim_6486 = tensor.dim %18395, %c0 : tensor + %18396 = arith.index_cast %dim_6486 : index to i64 + %from_elements_6487 = tensor.from_elements %c1_i64, %18396, %c4096_i64 : tensor<3xi64> + %18397 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_6487, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6488 = tensor.dim %18397, %c1 : tensor<1x?x4096xi64> + %18398 = arith.index_cast %dim_6488 : index to i64 + %from_elements_6489 = tensor.from_elements %c1_i64, %18398, %c4096_i64, %c1_i64 : tensor<4xi64> + %18399 = stablehlo.dynamic_reshape %18397, %from_elements_6489 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %18400 = stablehlo.dynamic_broadcast_in_dim %18395, %from_elements_6487, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6490 = tensor.dim %18400, %c1 : tensor<1x?x4096xi64> + %18401 = arith.index_cast %dim_6490 : index to i64 + %from_elements_6491 = tensor.from_elements %c1_i64, %18401, %c4096_i64, %c1_i64 : tensor<4xi64> + %18402 = stablehlo.dynamic_reshape %18400, %from_elements_6491 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %18403 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_6487, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6492 = tensor.dim %18403, %c1 : tensor<1x?x4096xi64> + %18404 = arith.index_cast %dim_6492 : index to i64 + %from_elements_6493 = tensor.from_elements %c1_i64, %18404, %c4096_i64, %c1_i64 : tensor<4xi64> + %18405 = stablehlo.dynamic_reshape %18403, %from_elements_6493 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %18406 = stablehlo.concatenate %18399, %18402, %18405, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %18407 = "stablehlo.gather"(%17951, %18406) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %18408 = shape.shape_of %18407 : tensor<1x?x4096xf32> -> tensor<3xindex> + %18409 = shape.num_elements %18408 : 
tensor<3xindex> -> index + %18410 = stablehlo.compute_reshape_shape %18409, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %18411 = stablehlo.dynamic_reshape %18407, %18410 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %18412 = stablehlo.dot %18411, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %18413 = stablehlo.logistic %18412 : tensor + %18414 = shape.shape_of %18413 : tensor -> tensor<2xindex> + %18415 = shape.shape_of %18412 : tensor -> tensor<2xindex> + %18416 = shape.cstr_broadcastable %18414, %18415 : tensor<2xindex>, tensor<2xindex> + %18417 = shape.assuming %18416 -> (tensor) { + %19688 = shape.broadcast %18414, %18415 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %18413, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %18412, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %18418 = shape.shape_of %18417 : tensor -> tensor<2xindex> + %18419 = shape.cstr_broadcastable %18418, %18415 : tensor<2xindex>, tensor<2xindex> + %18420 = shape.assuming %18419 -> (tensor) { + %19688 = shape.broadcast %18418, %18415 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %18417, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %18412, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %18421 = stablehlo.dot %18420, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_6494 = tensor.dim %18393, %c0 : tensor + %18422 = arith.index_cast %dim_6494 : index to i64 + %from_elements_6495 = tensor.from_elements %18422, %c1_i64 : tensor<2xi64> + %18423 = stablehlo.dynamic_reshape %18393, %from_elements_6495 : (tensor, tensor<2xi64>) -> tensor + %dim_6496 = tensor.dim %18390, %c0 : tensor + %18424 = arith.index_cast %dim_6496 : index to i64 + %from_elements_6497 = tensor.from_elements %18424, %c1_i64 : tensor<2xi64> + %18425 = stablehlo.dynamic_reshape %18390, %from_elements_6497 : (tensor, tensor<2xi64>) -> tensor + %18426 = stablehlo.concatenate %18423, %18425, dim = 1 : (tensor, tensor) -> tensor + %18427 = "stablehlo.gather"(%17980, %18426) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %18428 = shape.shape_of %18421 : tensor -> tensor<2xindex> + %18429 = shape.shape_of %18427 : tensor -> tensor<2xindex> + %18430 = shape.cstr_broadcastable %18428, %18429 : tensor<2xindex>, tensor<2xindex> + %18431 = shape.assuming %18430 -> (tensor) { + %19688 = shape.broadcast %18428, %18429 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %18421, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %18427, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %18432 = shape.shape_of %18431 : tensor -> tensor<2xindex> + %18433 = stablehlo.dynamic_broadcast_in_dim %18431, %18432, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %18434 = stablehlo.dynamic_broadcast_in_dim %213, %18432, dims = [] : (tensor, tensor<2xindex>) -> tensor + %18435 = stablehlo.multiply %18433, %18434 : 
tensor + %dim_6498 = tensor.dim %18395, %c0 : tensor + %18436 = arith.index_cast %dim_6498 : index to i64 + %dim_6499 = tensor.dim %18431, %c0 : tensor + %18437 = arith.index_cast %dim_6499 : index to i64 + %18438 = arith.maxsi %18436, %18437 : i64 + %18439 = arith.index_cast %18438 : i64 to index + %from_elements_6500 = tensor.from_elements %18439, %c4096 : tensor<2xindex> + %18440 = stablehlo.dynamic_broadcast_in_dim %18395, %from_elements_6500, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_6501 = tensor.dim %18440, %c0 : tensor + %18441 = arith.index_cast %dim_6501 : index to i64 + %from_elements_6502 = tensor.from_elements %18441, %c4096_i64 : tensor<2xi64> + %18442 = stablehlo.real_dynamic_slice %18435, %c_22, %from_elements_6502, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_6503 = tensor.from_elements %18441, %c4096_i64, %c1_i64 : tensor<3xi64> + %18443 = stablehlo.dynamic_reshape %18440, %from_elements_6503 : (tensor, tensor<3xi64>) -> tensor + %18444 = stablehlo.dynamic_iota %from_elements_6503, dim = 1 : (tensor<3xi64>) -> tensor + %18445 = stablehlo.concatenate %18443, %18444, dim = 2 : (tensor, tensor) -> tensor + %18446 = "stablehlo.scatter"(%18383, %18445, %18442) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %18447 = stablehlo.reshape %18446 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %18448 = stablehlo.add %17913, %18447 : tensor<3x1x4096xf32> + %18449 = stablehlo.broadcast_in_dim %18448, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %18450 = stablehlo.power %18449, %15 : tensor<3x1x4096xf32> + %18451 = stablehlo.reduce(%18450 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %18452 = stablehlo.reshape %18451 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %18453 = stablehlo.broadcast_in_dim %18452, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %18454 = stablehlo.divide %18453, %21 : tensor<3x1x1xf32> + %18455 = stablehlo.broadcast_in_dim %18454, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %18456 = stablehlo.add %18455, %25 : tensor<3x1x1xf32> + %18457 = stablehlo.rsqrt %18456 : tensor<3x1x1xf32> + %18458 = stablehlo.broadcast_in_dim %18457, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %18459 = stablehlo.multiply %18449, %18458 : tensor<3x1x4096xf32> + %18460 = stablehlo.broadcast_in_dim %18459, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %18461 = stablehlo.multiply %18460, %31 : tensor<3x1x4096xf32> + %18462 = stablehlo.reshape %18461 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %18463 = stablehlo.dot %18462, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %18464 = stablehlo.reshape %18463 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %18465 = stablehlo.dot %18462, %37 : (tensor<3x4096xf32>, tensor<4096x1024xf32>) -> tensor<3x1024xf32> + %18466 = stablehlo.reshape %18465 : (tensor<3x1024xf32>) -> tensor<3x1x1024xf32> + %18467 = stablehlo.reshape %18464 : (tensor<3x1x4096xf32>) -> tensor<3x1x32x128xf32> + %18468 = stablehlo.transpose %18467, dims = [0, 2, 1, 3] : (tensor<3x1x32x128xf32>) -> tensor<3x32x1x128xf32> + %18469 = stablehlo.reshape %18466 : (tensor<3x1x1024xf32>) 
-> tensor<3x1x8x128xf32> + %18470 = stablehlo.transpose %18469, dims = [0, 2, 1, 3] : (tensor<3x1x8x128xf32>) -> tensor<3x8x1x128xf32> + %18471 = stablehlo.slice %arg60 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %18472 = stablehlo.slice %arg61 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %18473 = "stablehlo.gather"(%18471, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %18474 = stablehlo.reshape %18473 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %18475 = "stablehlo.gather"(%18472, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %18476 = stablehlo.reshape %18475 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %18477 = stablehlo.broadcast_in_dim %18468, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %18478 = stablehlo.broadcast_in_dim %18474, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %18479 = stablehlo.multiply %18477, %18478 : tensor<3x32x1x128xf32> + %18480 = stablehlo.slice %18468 [0:3, 0:32, 0:1, 0:64] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %18481 = stablehlo.slice %18468 [0:3, 0:32, 0:1, 64:128] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %18482 = stablehlo.negate %18481 : tensor<3x32x1x64xf32> + %18483 = stablehlo.concatenate %18482, %18480, dim = 3 : (tensor<3x32x1x64xf32>, tensor<3x32x1x64xf32>) -> tensor<3x32x1x128xf32> + %18484 = stablehlo.broadcast_in_dim %18483, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %18485 = stablehlo.broadcast_in_dim %18476, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %18486 = stablehlo.multiply %18484, %18485 : tensor<3x32x1x128xf32> + %18487 = stablehlo.add %18479, %18486 : tensor<3x32x1x128xf32> + %18488 = stablehlo.broadcast_in_dim %18470, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %18489 = stablehlo.broadcast_in_dim %18474, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %18490 = stablehlo.multiply %18488, %18489 : tensor<3x8x1x128xf32> + %18491 = stablehlo.slice %18470 [0:3, 0:8, 0:1, 0:64] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %18492 = stablehlo.slice %18470 [0:3, 0:8, 0:1, 64:128] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %18493 = stablehlo.negate %18492 : tensor<3x8x1x64xf32> + %18494 = stablehlo.concatenate %18493, %18491, dim = 3 : (tensor<3x8x1x64xf32>, tensor<3x8x1x64xf32>) -> tensor<3x8x1x128xf32> + %18495 = stablehlo.broadcast_in_dim %18494, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %18496 = stablehlo.broadcast_in_dim %18476, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %18497 = stablehlo.multiply %18495, %18496 : tensor<3x8x1x128xf32> + %18498 = stablehlo.add %18490, %18497 : tensor<3x8x1x128xf32> + %18499 = stablehlo.concatenate %arg125, %18498, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %18500 = stablehlo.concatenate %arg126, %18470, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %18501 = stablehlo.reshape %18499 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %18502 = stablehlo.broadcast_in_dim %18501, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %18503 = 
stablehlo.reshape %18502 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %18504 = stablehlo.reshape %18500 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %18505 = stablehlo.broadcast_in_dim %18504, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %18506 = stablehlo.reshape %18505 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %18507 = stablehlo.transpose %18503, dims = [0, 1, 3, 2] : (tensor<3x32x8x128xf32>) -> tensor<3x32x128x8xf32> + %18508 = stablehlo.reshape %18487 : (tensor<3x32x1x128xf32>) -> tensor<96x1x128xf32> + %18509 = stablehlo.reshape %18507 : (tensor<3x32x128x8xf32>) -> tensor<96x128x8xf32> + %18510 = stablehlo.broadcast_in_dim %18509, dims = [0, 1, 2] : (tensor<96x128x8xf32>) -> tensor<96x128x8xf32> + %18511 = stablehlo.dot_general %18508, %18510, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x128xf32>, tensor<96x128x8xf32>) -> tensor<96x1x8xf32> + %18512 = stablehlo.reshape %18511 : (tensor<96x1x8xf32>) -> tensor<3x32x1x8xf32> + %18513 = stablehlo.broadcast_in_dim %18512, dims = [0, 1, 2, 3] : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %18514 = stablehlo.divide %18513, %89 : tensor<3x32x1x8xf32> + %18515 = stablehlo.custom_call @byteir.softmax(%18514) {byteir_attrs = {axis = 3 : i64}} : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %18516 = stablehlo.reshape %18515 : (tensor<3x32x1x8xf32>) -> tensor<96x1x8xf32> + %18517 = stablehlo.reshape %18506 : (tensor<3x32x8x128xf32>) -> tensor<96x8x128xf32> + %18518 = stablehlo.broadcast_in_dim %18517, dims = [0, 1, 2] : (tensor<96x8x128xf32>) -> tensor<96x8x128xf32> + %18519 = stablehlo.dot_general %18516, %18518, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x8xf32>, tensor<96x8x128xf32>) -> tensor<96x1x128xf32> + %18520 = stablehlo.reshape %18519 : (tensor<96x1x128xf32>) -> tensor<3x32x1x128xf32> + %18521 = stablehlo.transpose %18520, dims = [0, 2, 1, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x1x32x128xf32> + %18522 = stablehlo.reshape %18521 : (tensor<3x1x32x128xf32>) -> tensor<3x1x4096xf32> + %18523 = stablehlo.reshape %18522 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %18524 = stablehlo.dot %18523, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %18525 = stablehlo.reshape %18524 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %18526 = stablehlo.add %18448, %18525 : tensor<3x1x4096xf32> + %18527 = stablehlo.broadcast_in_dim %18526, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %18528 = stablehlo.power %18527, %15 : tensor<3x1x4096xf32> + %18529 = stablehlo.reduce(%18528 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %18530 = stablehlo.reshape %18529 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %18531 = stablehlo.broadcast_in_dim %18530, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %18532 = stablehlo.divide %18531, %21 : tensor<3x1x1xf32> + %18533 = stablehlo.broadcast_in_dim %18532, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %18534 = stablehlo.add %18533, %25 : tensor<3x1x1xf32> + %18535 = stablehlo.rsqrt %18534 : tensor<3x1x1xf32> + %18536 = stablehlo.broadcast_in_dim %18535, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %18537 = stablehlo.multiply %18527, %18536 : tensor<3x1x4096xf32> + %18538 = stablehlo.broadcast_in_dim %18537, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %18539 = stablehlo.multiply %18538, %31 : 
tensor<3x1x4096xf32> + %18540 = stablehlo.reshape %18539 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %18541 = stablehlo.dot %18540, %117 : (tensor<3x4096xf32>, tensor<4096x8xf32>) -> tensor<3x8xf32> + %18542 = stablehlo.custom_call @byteir.softmax(%18541) {byteir_attrs = {axis = 1 : i64}} : (tensor<3x8xf32>) -> tensor<3x8xf32> + %18543:2 = stablehlo.custom_call @byteir.top_k(%18542) {byteir_attrs = {axis = [1], k = 2 : i64, sorted = true}} : (tensor<3x8xf32>) -> (tensor<3x2xf32>, tensor<3x2xi64>) + %18544 = stablehlo.reduce(%18543#0 init: %cst_20) applies stablehlo.add across dimensions = [1] : (tensor<3x2xf32>, tensor) -> tensor<3xf32> + %18545 = stablehlo.reshape %18544 : (tensor<3xf32>) -> tensor<3x1xf32> + %18546 = stablehlo.broadcast_in_dim %18543#0, dims = [0, 1] : (tensor<3x2xf32>) -> tensor<3x2xf32> + %18547 = stablehlo.broadcast_in_dim %18545, dims = [0, 1] : (tensor<3x1xf32>) -> tensor<3x2xf32> + %18548 = stablehlo.divide %18546, %18547 : tensor<3x2xf32> + %18549 = stablehlo.reshape %18543#1 : (tensor<3x2xi64>) -> tensor<3x2x1xi64> + %18550 = stablehlo.broadcast_in_dim %18549, dims = [0, 1, 2] : (tensor<3x2x1xi64>) -> tensor<3x2x8xi64> + %18551 = stablehlo.compare EQ, %18550, %137, SIGNED : (tensor<3x2x8xi64>, tensor<3x2x8xi64>) -> tensor<3x2x8xi1> + %18552 = stablehlo.convert %18551 : (tensor<3x2x8xi1>) -> tensor<3x2x8xi64> + %18553 = stablehlo.transpose %18552, dims = [2, 1, 0] : (tensor<3x2x8xi64>) -> tensor<8x2x3xi64> + %18554 = stablehlo.slice %18553 [0:1, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %18555 = stablehlo.reshape %18554 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %18556 = stablehlo.custom_call @byteir.non_zero(%18555) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_6504 = tensor.dim %18556, %c0 : tensor + %18557 = arith.index_cast %dim_6504 : index to i64 + %from_elements_6505 = tensor.from_elements %18557, %c1_i64 : tensor<2xi64> + %18558 = stablehlo.real_dynamic_slice %18556, %c_22, %from_elements_6505, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6506 = tensor.dim %18558, %c0 : tensor + %18559 = arith.index_cast %dim_6506 : index to i64 + %from_elements_6507 = tensor.from_elements %18559 : tensor<1xi64> + %18560 = stablehlo.dynamic_reshape %18558, %from_elements_6507 : (tensor, tensor<1xi64>) -> tensor + %from_elements_6508 = tensor.from_elements %18557, %c2_i64 : tensor<2xi64> + %18561 = stablehlo.real_dynamic_slice %18556, %c_24, %from_elements_6508, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6509 = tensor.dim %18561, %c0 : tensor + %18562 = arith.index_cast %dim_6509 : index to i64 + %from_elements_6510 = tensor.from_elements %18562 : tensor<1xi64> + %18563 = stablehlo.dynamic_reshape %18561, %from_elements_6510 : (tensor, tensor<1xi64>) -> tensor + %18564 = stablehlo.reshape %18540 : (tensor<3x4096xf32>) -> tensor<1x3x4096xf32> + %dim_6511 = tensor.dim %18563, %c0 : tensor + %18565 = arith.index_cast %dim_6511 : index to i64 + %from_elements_6512 = tensor.from_elements %18565, %c1_i64 : tensor<2xi64> + %18566 = stablehlo.dynamic_reshape %18563, %from_elements_6512 : (tensor, tensor<2xi64>) -> tensor + %dim_6513 = tensor.dim %18566, %c0 : tensor + %18567 = arith.index_cast %dim_6513 : index to i64 + %from_elements_6514 = tensor.from_elements %c1_i64, %18567, %c4096_i64 : tensor<3xi64> + %18568 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_6514, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6515 = tensor.dim 
%18568, %c1 : tensor<1x?x4096xi64> + %18569 = arith.index_cast %dim_6515 : index to i64 + %from_elements_6516 = tensor.from_elements %c1_i64, %18569, %c4096_i64, %c1_i64 : tensor<4xi64> + %18570 = stablehlo.dynamic_reshape %18568, %from_elements_6516 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %18571 = stablehlo.dynamic_broadcast_in_dim %18566, %from_elements_6514, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6517 = tensor.dim %18571, %c1 : tensor<1x?x4096xi64> + %18572 = arith.index_cast %dim_6517 : index to i64 + %from_elements_6518 = tensor.from_elements %c1_i64, %18572, %c4096_i64, %c1_i64 : tensor<4xi64> + %18573 = stablehlo.dynamic_reshape %18571, %from_elements_6518 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %18574 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_6514, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6519 = tensor.dim %18574, %c1 : tensor<1x?x4096xi64> + %18575 = arith.index_cast %dim_6519 : index to i64 + %from_elements_6520 = tensor.from_elements %c1_i64, %18575, %c4096_i64, %c1_i64 : tensor<4xi64> + %18576 = stablehlo.dynamic_reshape %18574, %from_elements_6520 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %18577 = stablehlo.concatenate %18570, %18573, %18576, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %18578 = "stablehlo.gather"(%18564, %18577) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %18579 = shape.shape_of %18578 : tensor<1x?x4096xf32> -> tensor<3xindex> + %18580 = shape.num_elements %18579 : tensor<3xindex> -> index + %18581 = stablehlo.compute_reshape_shape %18580, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %18582 = stablehlo.dynamic_reshape %18578, %18581 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %18583 = stablehlo.dot %18582, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %18584 = stablehlo.logistic %18583 : tensor + %18585 = shape.shape_of %18584 : tensor -> tensor<2xindex> + %18586 = shape.shape_of %18583 : tensor -> tensor<2xindex> + %18587 = shape.cstr_broadcastable %18585, %18586 : tensor<2xindex>, tensor<2xindex> + %18588 = shape.assuming %18587 -> (tensor) { + %19688 = shape.broadcast %18585, %18586 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %18584, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %18583, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %18589 = shape.shape_of %18588 : tensor -> tensor<2xindex> + %18590 = shape.cstr_broadcastable %18589, %18586 : tensor<2xindex>, tensor<2xindex> + %18591 = shape.assuming %18590 -> (tensor) { + %19688 = shape.broadcast %18589, %18586 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %18588, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %18583, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %18592 = stablehlo.dot %18591, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %18593 = stablehlo.reshape %18548 : 
(tensor<3x2xf32>) -> tensor<3x2x1xf32> + %dim_6521 = tensor.dim %18563, %c0 : tensor + %18594 = arith.index_cast %dim_6521 : index to i64 + %from_elements_6522 = tensor.from_elements %18594, %c1_i64 : tensor<2xi64> + %18595 = stablehlo.dynamic_reshape %18563, %from_elements_6522 : (tensor, tensor<2xi64>) -> tensor + %dim_6523 = tensor.dim %18560, %c0 : tensor + %18596 = arith.index_cast %dim_6523 : index to i64 + %from_elements_6524 = tensor.from_elements %18596, %c1_i64 : tensor<2xi64> + %18597 = stablehlo.dynamic_reshape %18560, %from_elements_6524 : (tensor, tensor<2xi64>) -> tensor + %18598 = stablehlo.concatenate %18595, %18597, dim = 1 : (tensor, tensor) -> tensor + %18599 = "stablehlo.gather"(%18593, %18598) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %18600 = shape.shape_of %18592 : tensor -> tensor<2xindex> + %18601 = shape.shape_of %18599 : tensor -> tensor<2xindex> + %18602 = shape.cstr_broadcastable %18600, %18601 : tensor<2xindex>, tensor<2xindex> + %18603 = shape.assuming %18602 -> (tensor) { + %19688 = shape.broadcast %18600, %18601 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %18592, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %18599, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %18604 = shape.shape_of %18603 : tensor -> tensor<2xindex> + %18605 = stablehlo.dynamic_broadcast_in_dim %18603, %18604, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %18606 = stablehlo.dynamic_broadcast_in_dim %213, %18604, dims = [] : (tensor, tensor<2xindex>) -> tensor + %18607 = stablehlo.multiply %18605, %18606 : tensor + %dim_6525 = tensor.dim %18566, %c0 : tensor + %18608 = arith.index_cast %dim_6525 : index to i64 + %dim_6526 = tensor.dim %18603, %c0 : tensor + %18609 = arith.index_cast %dim_6526 : index to i64 + %18610 = arith.maxsi %18608, %18609 : i64 + %18611 = arith.index_cast %18610 : i64 to index + %from_elements_6527 = tensor.from_elements %18611, %c4096 : tensor<2xindex> + %18612 = stablehlo.dynamic_broadcast_in_dim %18566, %from_elements_6527, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_6528 = tensor.dim %18612, %c0 : tensor + %18613 = arith.index_cast %dim_6528 : index to i64 + %from_elements_6529 = tensor.from_elements %18613, %c4096_i64 : tensor<2xi64> + %18614 = stablehlo.real_dynamic_slice %18607, %c_22, %from_elements_6529, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_6530 = tensor.from_elements %18613, %c4096_i64, %c1_i64 : tensor<3xi64> + %18615 = stablehlo.dynamic_reshape %18612, %from_elements_6530 : (tensor, tensor<3xi64>) -> tensor + %18616 = stablehlo.dynamic_iota %from_elements_6530, dim = 1 : (tensor<3xi64>) -> tensor + %18617 = stablehlo.concatenate %18615, %18616, dim = 2 : (tensor, tensor) -> tensor + %18618 = "stablehlo.scatter"(%cst_2, %18617, %18614) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %18619 = stablehlo.slice %18553 [1:2, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %18620 = stablehlo.reshape %18619 : 
(tensor<1x2x3xi64>) -> tensor<2x3xi64> + %18621 = stablehlo.custom_call @byteir.non_zero(%18620) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_6531 = tensor.dim %18621, %c0 : tensor + %18622 = arith.index_cast %dim_6531 : index to i64 + %from_elements_6532 = tensor.from_elements %18622, %c1_i64 : tensor<2xi64> + %18623 = stablehlo.real_dynamic_slice %18621, %c_22, %from_elements_6532, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6533 = tensor.dim %18623, %c0 : tensor + %18624 = arith.index_cast %dim_6533 : index to i64 + %from_elements_6534 = tensor.from_elements %18624 : tensor<1xi64> + %18625 = stablehlo.dynamic_reshape %18623, %from_elements_6534 : (tensor, tensor<1xi64>) -> tensor + %from_elements_6535 = tensor.from_elements %18622, %c2_i64 : tensor<2xi64> + %18626 = stablehlo.real_dynamic_slice %18621, %c_24, %from_elements_6535, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6536 = tensor.dim %18626, %c0 : tensor + %18627 = arith.index_cast %dim_6536 : index to i64 + %from_elements_6537 = tensor.from_elements %18627 : tensor<1xi64> + %18628 = stablehlo.dynamic_reshape %18626, %from_elements_6537 : (tensor, tensor<1xi64>) -> tensor + %dim_6538 = tensor.dim %18628, %c0 : tensor + %18629 = arith.index_cast %dim_6538 : index to i64 + %from_elements_6539 = tensor.from_elements %18629, %c1_i64 : tensor<2xi64> + %18630 = stablehlo.dynamic_reshape %18628, %from_elements_6539 : (tensor, tensor<2xi64>) -> tensor + %dim_6540 = tensor.dim %18630, %c0 : tensor + %18631 = arith.index_cast %dim_6540 : index to i64 + %from_elements_6541 = tensor.from_elements %c1_i64, %18631, %c4096_i64 : tensor<3xi64> + %18632 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_6541, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6542 = tensor.dim %18632, %c1 : tensor<1x?x4096xi64> + %18633 = arith.index_cast %dim_6542 : index to i64 + %from_elements_6543 = tensor.from_elements %c1_i64, %18633, %c4096_i64, %c1_i64 : tensor<4xi64> + %18634 = stablehlo.dynamic_reshape %18632, %from_elements_6543 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %18635 = stablehlo.dynamic_broadcast_in_dim %18630, %from_elements_6541, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6544 = tensor.dim %18635, %c1 : tensor<1x?x4096xi64> + %18636 = arith.index_cast %dim_6544 : index to i64 + %from_elements_6545 = tensor.from_elements %c1_i64, %18636, %c4096_i64, %c1_i64 : tensor<4xi64> + %18637 = stablehlo.dynamic_reshape %18635, %from_elements_6545 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %18638 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_6541, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6546 = tensor.dim %18638, %c1 : tensor<1x?x4096xi64> + %18639 = arith.index_cast %dim_6546 : index to i64 + %from_elements_6547 = tensor.from_elements %c1_i64, %18639, %c4096_i64, %c1_i64 : tensor<4xi64> + %18640 = stablehlo.dynamic_reshape %18638, %from_elements_6547 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %18641 = stablehlo.concatenate %18634, %18637, %18640, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %18642 = "stablehlo.gather"(%18564, %18641) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %18643 = 
shape.shape_of %18642 : tensor<1x?x4096xf32> -> tensor<3xindex> + %18644 = shape.num_elements %18643 : tensor<3xindex> -> index + %18645 = stablehlo.compute_reshape_shape %18644, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %18646 = stablehlo.dynamic_reshape %18642, %18645 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %18647 = stablehlo.dot %18646, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %18648 = stablehlo.logistic %18647 : tensor + %18649 = shape.shape_of %18648 : tensor -> tensor<2xindex> + %18650 = shape.shape_of %18647 : tensor -> tensor<2xindex> + %18651 = shape.cstr_broadcastable %18649, %18650 : tensor<2xindex>, tensor<2xindex> + %18652 = shape.assuming %18651 -> (tensor) { + %19688 = shape.broadcast %18649, %18650 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %18648, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %18647, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %18653 = shape.shape_of %18652 : tensor -> tensor<2xindex> + %18654 = shape.cstr_broadcastable %18653, %18650 : tensor<2xindex>, tensor<2xindex> + %18655 = shape.assuming %18654 -> (tensor) { + %19688 = shape.broadcast %18653, %18650 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %18652, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %18647, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %18656 = stablehlo.dot %18655, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_6548 = tensor.dim %18628, %c0 : tensor + %18657 = arith.index_cast %dim_6548 : index to i64 + %from_elements_6549 = tensor.from_elements %18657, %c1_i64 : tensor<2xi64> + %18658 = stablehlo.dynamic_reshape %18628, %from_elements_6549 : (tensor, tensor<2xi64>) -> tensor + %dim_6550 = tensor.dim %18625, %c0 : tensor + %18659 = arith.index_cast %dim_6550 : index to i64 + %from_elements_6551 = tensor.from_elements %18659, %c1_i64 : tensor<2xi64> + %18660 = stablehlo.dynamic_reshape %18625, %from_elements_6551 : (tensor, tensor<2xi64>) -> tensor + %18661 = stablehlo.concatenate %18658, %18660, dim = 1 : (tensor, tensor) -> tensor + %18662 = "stablehlo.gather"(%18593, %18661) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %18663 = shape.shape_of %18656 : tensor -> tensor<2xindex> + %18664 = shape.shape_of %18662 : tensor -> tensor<2xindex> + %18665 = shape.cstr_broadcastable %18663, %18664 : tensor<2xindex>, tensor<2xindex> + %18666 = shape.assuming %18665 -> (tensor) { + %19688 = shape.broadcast %18663, %18664 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %18656, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %18662, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %18667 = shape.shape_of %18666 : tensor -> tensor<2xindex> + %18668 = stablehlo.dynamic_broadcast_in_dim %18666, %18667, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %18669 = stablehlo.dynamic_broadcast_in_dim %213, 
%18667, dims = [] : (tensor, tensor<2xindex>) -> tensor + %18670 = stablehlo.multiply %18668, %18669 : tensor + %dim_6552 = tensor.dim %18630, %c0 : tensor + %18671 = arith.index_cast %dim_6552 : index to i64 + %dim_6553 = tensor.dim %18666, %c0 : tensor + %18672 = arith.index_cast %dim_6553 : index to i64 + %18673 = arith.maxsi %18671, %18672 : i64 + %18674 = arith.index_cast %18673 : i64 to index + %from_elements_6554 = tensor.from_elements %18674, %c4096 : tensor<2xindex> + %18675 = stablehlo.dynamic_broadcast_in_dim %18630, %from_elements_6554, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_6555 = tensor.dim %18675, %c0 : tensor + %18676 = arith.index_cast %dim_6555 : index to i64 + %from_elements_6556 = tensor.from_elements %18676, %c4096_i64 : tensor<2xi64> + %18677 = stablehlo.real_dynamic_slice %18670, %c_22, %from_elements_6556, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_6557 = tensor.from_elements %18676, %c4096_i64, %c1_i64 : tensor<3xi64> + %18678 = stablehlo.dynamic_reshape %18675, %from_elements_6557 : (tensor, tensor<3xi64>) -> tensor + %18679 = stablehlo.dynamic_iota %from_elements_6557, dim = 1 : (tensor<3xi64>) -> tensor + %18680 = stablehlo.concatenate %18678, %18679, dim = 2 : (tensor, tensor) -> tensor + %18681 = "stablehlo.scatter"(%18618, %18680, %18677) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %18682 = stablehlo.slice %18553 [2:3, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %18683 = stablehlo.reshape %18682 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %18684 = stablehlo.custom_call @byteir.non_zero(%18683) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_6558 = tensor.dim %18684, %c0 : tensor + %18685 = arith.index_cast %dim_6558 : index to i64 + %from_elements_6559 = tensor.from_elements %18685, %c1_i64 : tensor<2xi64> + %18686 = stablehlo.real_dynamic_slice %18684, %c_22, %from_elements_6559, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6560 = tensor.dim %18686, %c0 : tensor + %18687 = arith.index_cast %dim_6560 : index to i64 + %from_elements_6561 = tensor.from_elements %18687 : tensor<1xi64> + %18688 = stablehlo.dynamic_reshape %18686, %from_elements_6561 : (tensor, tensor<1xi64>) -> tensor + %from_elements_6562 = tensor.from_elements %18685, %c2_i64 : tensor<2xi64> + %18689 = stablehlo.real_dynamic_slice %18684, %c_24, %from_elements_6562, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6563 = tensor.dim %18689, %c0 : tensor + %18690 = arith.index_cast %dim_6563 : index to i64 + %from_elements_6564 = tensor.from_elements %18690 : tensor<1xi64> + %18691 = stablehlo.dynamic_reshape %18689, %from_elements_6564 : (tensor, tensor<1xi64>) -> tensor + %dim_6565 = tensor.dim %18691, %c0 : tensor + %18692 = arith.index_cast %dim_6565 : index to i64 + %from_elements_6566 = tensor.from_elements %18692, %c1_i64 : tensor<2xi64> + %18693 = stablehlo.dynamic_reshape %18691, %from_elements_6566 : (tensor, tensor<2xi64>) -> tensor + %dim_6567 = tensor.dim %18693, %c0 : tensor + %18694 = arith.index_cast %dim_6567 : index to i64 + %from_elements_6568 = tensor.from_elements %c1_i64, %18694, %c4096_i64 : tensor<3xi64> + %18695 = stablehlo.dynamic_broadcast_in_dim %173, 
%from_elements_6568, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6569 = tensor.dim %18695, %c1 : tensor<1x?x4096xi64> + %18696 = arith.index_cast %dim_6569 : index to i64 + %from_elements_6570 = tensor.from_elements %c1_i64, %18696, %c4096_i64, %c1_i64 : tensor<4xi64> + %18697 = stablehlo.dynamic_reshape %18695, %from_elements_6570 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %18698 = stablehlo.dynamic_broadcast_in_dim %18693, %from_elements_6568, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6571 = tensor.dim %18698, %c1 : tensor<1x?x4096xi64> + %18699 = arith.index_cast %dim_6571 : index to i64 + %from_elements_6572 = tensor.from_elements %c1_i64, %18699, %c4096_i64, %c1_i64 : tensor<4xi64> + %18700 = stablehlo.dynamic_reshape %18698, %from_elements_6572 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %18701 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_6568, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6573 = tensor.dim %18701, %c1 : tensor<1x?x4096xi64> + %18702 = arith.index_cast %dim_6573 : index to i64 + %from_elements_6574 = tensor.from_elements %c1_i64, %18702, %c4096_i64, %c1_i64 : tensor<4xi64> + %18703 = stablehlo.dynamic_reshape %18701, %from_elements_6574 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %18704 = stablehlo.concatenate %18697, %18700, %18703, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %18705 = "stablehlo.gather"(%18564, %18704) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %18706 = shape.shape_of %18705 : tensor<1x?x4096xf32> -> tensor<3xindex> + %18707 = shape.num_elements %18706 : tensor<3xindex> -> index + %18708 = stablehlo.compute_reshape_shape %18707, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %18709 = stablehlo.dynamic_reshape %18705, %18708 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %18710 = stablehlo.dot %18709, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %18711 = stablehlo.logistic %18710 : tensor + %18712 = shape.shape_of %18711 : tensor -> tensor<2xindex> + %18713 = shape.shape_of %18710 : tensor -> tensor<2xindex> + %18714 = shape.cstr_broadcastable %18712, %18713 : tensor<2xindex>, tensor<2xindex> + %18715 = shape.assuming %18714 -> (tensor) { + %19688 = shape.broadcast %18712, %18713 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %18711, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %18710, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %18716 = shape.shape_of %18715 : tensor -> tensor<2xindex> + %18717 = shape.cstr_broadcastable %18716, %18713 : tensor<2xindex>, tensor<2xindex> + %18718 = shape.assuming %18717 -> (tensor) { + %19688 = shape.broadcast %18716, %18713 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %18715, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %18710, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %18719 = 
stablehlo.dot %18718, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_6575 = tensor.dim %18691, %c0 : tensor + %18720 = arith.index_cast %dim_6575 : index to i64 + %from_elements_6576 = tensor.from_elements %18720, %c1_i64 : tensor<2xi64> + %18721 = stablehlo.dynamic_reshape %18691, %from_elements_6576 : (tensor, tensor<2xi64>) -> tensor + %dim_6577 = tensor.dim %18688, %c0 : tensor + %18722 = arith.index_cast %dim_6577 : index to i64 + %from_elements_6578 = tensor.from_elements %18722, %c1_i64 : tensor<2xi64> + %18723 = stablehlo.dynamic_reshape %18688, %from_elements_6578 : (tensor, tensor<2xi64>) -> tensor + %18724 = stablehlo.concatenate %18721, %18723, dim = 1 : (tensor, tensor) -> tensor + %18725 = "stablehlo.gather"(%18593, %18724) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %18726 = shape.shape_of %18719 : tensor -> tensor<2xindex> + %18727 = shape.shape_of %18725 : tensor -> tensor<2xindex> + %18728 = shape.cstr_broadcastable %18726, %18727 : tensor<2xindex>, tensor<2xindex> + %18729 = shape.assuming %18728 -> (tensor) { + %19688 = shape.broadcast %18726, %18727 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %18719, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %18725, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %18730 = shape.shape_of %18729 : tensor -> tensor<2xindex> + %18731 = stablehlo.dynamic_broadcast_in_dim %18729, %18730, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %18732 = stablehlo.dynamic_broadcast_in_dim %213, %18730, dims = [] : (tensor, tensor<2xindex>) -> tensor + %18733 = stablehlo.multiply %18731, %18732 : tensor + %dim_6579 = tensor.dim %18693, %c0 : tensor + %18734 = arith.index_cast %dim_6579 : index to i64 + %dim_6580 = tensor.dim %18729, %c0 : tensor + %18735 = arith.index_cast %dim_6580 : index to i64 + %18736 = arith.maxsi %18734, %18735 : i64 + %18737 = arith.index_cast %18736 : i64 to index + %from_elements_6581 = tensor.from_elements %18737, %c4096 : tensor<2xindex> + %18738 = stablehlo.dynamic_broadcast_in_dim %18693, %from_elements_6581, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_6582 = tensor.dim %18738, %c0 : tensor + %18739 = arith.index_cast %dim_6582 : index to i64 + %from_elements_6583 = tensor.from_elements %18739, %c4096_i64 : tensor<2xi64> + %18740 = stablehlo.real_dynamic_slice %18733, %c_22, %from_elements_6583, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_6584 = tensor.from_elements %18739, %c4096_i64, %c1_i64 : tensor<3xi64> + %18741 = stablehlo.dynamic_reshape %18738, %from_elements_6584 : (tensor, tensor<3xi64>) -> tensor + %18742 = stablehlo.dynamic_iota %from_elements_6584, dim = 1 : (tensor<3xi64>) -> tensor + %18743 = stablehlo.concatenate %18741, %18742, dim = 2 : (tensor, tensor) -> tensor + %18744 = "stablehlo.scatter"(%18681, %18743, %18740) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %18745 = stablehlo.slice %18553 [3:4, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %18746 = 
stablehlo.reshape %18745 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %18747 = stablehlo.custom_call @byteir.non_zero(%18746) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_6585 = tensor.dim %18747, %c0 : tensor + %18748 = arith.index_cast %dim_6585 : index to i64 + %from_elements_6586 = tensor.from_elements %18748, %c1_i64 : tensor<2xi64> + %18749 = stablehlo.real_dynamic_slice %18747, %c_22, %from_elements_6586, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6587 = tensor.dim %18749, %c0 : tensor + %18750 = arith.index_cast %dim_6587 : index to i64 + %from_elements_6588 = tensor.from_elements %18750 : tensor<1xi64> + %18751 = stablehlo.dynamic_reshape %18749, %from_elements_6588 : (tensor, tensor<1xi64>) -> tensor + %from_elements_6589 = tensor.from_elements %18748, %c2_i64 : tensor<2xi64> + %18752 = stablehlo.real_dynamic_slice %18747, %c_24, %from_elements_6589, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6590 = tensor.dim %18752, %c0 : tensor + %18753 = arith.index_cast %dim_6590 : index to i64 + %from_elements_6591 = tensor.from_elements %18753 : tensor<1xi64> + %18754 = stablehlo.dynamic_reshape %18752, %from_elements_6591 : (tensor, tensor<1xi64>) -> tensor + %dim_6592 = tensor.dim %18754, %c0 : tensor + %18755 = arith.index_cast %dim_6592 : index to i64 + %from_elements_6593 = tensor.from_elements %18755, %c1_i64 : tensor<2xi64> + %18756 = stablehlo.dynamic_reshape %18754, %from_elements_6593 : (tensor, tensor<2xi64>) -> tensor + %dim_6594 = tensor.dim %18756, %c0 : tensor + %18757 = arith.index_cast %dim_6594 : index to i64 + %from_elements_6595 = tensor.from_elements %c1_i64, %18757, %c4096_i64 : tensor<3xi64> + %18758 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_6595, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6596 = tensor.dim %18758, %c1 : tensor<1x?x4096xi64> + %18759 = arith.index_cast %dim_6596 : index to i64 + %from_elements_6597 = tensor.from_elements %c1_i64, %18759, %c4096_i64, %c1_i64 : tensor<4xi64> + %18760 = stablehlo.dynamic_reshape %18758, %from_elements_6597 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %18761 = stablehlo.dynamic_broadcast_in_dim %18756, %from_elements_6595, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6598 = tensor.dim %18761, %c1 : tensor<1x?x4096xi64> + %18762 = arith.index_cast %dim_6598 : index to i64 + %from_elements_6599 = tensor.from_elements %c1_i64, %18762, %c4096_i64, %c1_i64 : tensor<4xi64> + %18763 = stablehlo.dynamic_reshape %18761, %from_elements_6599 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %18764 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_6595, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6600 = tensor.dim %18764, %c1 : tensor<1x?x4096xi64> + %18765 = arith.index_cast %dim_6600 : index to i64 + %from_elements_6601 = tensor.from_elements %c1_i64, %18765, %c4096_i64, %c1_i64 : tensor<4xi64> + %18766 = stablehlo.dynamic_reshape %18764, %from_elements_6601 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %18767 = stablehlo.concatenate %18760, %18763, %18766, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %18768 = "stablehlo.gather"(%18564, %18767) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> 
tensor<1x?x4096xf32> + %18769 = shape.shape_of %18768 : tensor<1x?x4096xf32> -> tensor<3xindex> + %18770 = shape.num_elements %18769 : tensor<3xindex> -> index + %18771 = stablehlo.compute_reshape_shape %18770, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %18772 = stablehlo.dynamic_reshape %18768, %18771 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %18773 = stablehlo.dot %18772, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %18774 = stablehlo.logistic %18773 : tensor + %18775 = shape.shape_of %18774 : tensor -> tensor<2xindex> + %18776 = shape.shape_of %18773 : tensor -> tensor<2xindex> + %18777 = shape.cstr_broadcastable %18775, %18776 : tensor<2xindex>, tensor<2xindex> + %18778 = shape.assuming %18777 -> (tensor) { + %19688 = shape.broadcast %18775, %18776 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %18774, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %18773, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %18779 = shape.shape_of %18778 : tensor -> tensor<2xindex> + %18780 = shape.cstr_broadcastable %18779, %18776 : tensor<2xindex>, tensor<2xindex> + %18781 = shape.assuming %18780 -> (tensor) { + %19688 = shape.broadcast %18779, %18776 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %18778, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %18773, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %18782 = stablehlo.dot %18781, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_6602 = tensor.dim %18754, %c0 : tensor + %18783 = arith.index_cast %dim_6602 : index to i64 + %from_elements_6603 = tensor.from_elements %18783, %c1_i64 : tensor<2xi64> + %18784 = stablehlo.dynamic_reshape %18754, %from_elements_6603 : (tensor, tensor<2xi64>) -> tensor + %dim_6604 = tensor.dim %18751, %c0 : tensor + %18785 = arith.index_cast %dim_6604 : index to i64 + %from_elements_6605 = tensor.from_elements %18785, %c1_i64 : tensor<2xi64> + %18786 = stablehlo.dynamic_reshape %18751, %from_elements_6605 : (tensor, tensor<2xi64>) -> tensor + %18787 = stablehlo.concatenate %18784, %18786, dim = 1 : (tensor, tensor) -> tensor + %18788 = "stablehlo.gather"(%18593, %18787) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %18789 = shape.shape_of %18782 : tensor -> tensor<2xindex> + %18790 = shape.shape_of %18788 : tensor -> tensor<2xindex> + %18791 = shape.cstr_broadcastable %18789, %18790 : tensor<2xindex>, tensor<2xindex> + %18792 = shape.assuming %18791 -> (tensor) { + %19688 = shape.broadcast %18789, %18790 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %18782, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %18788, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %18793 = shape.shape_of %18792 : tensor -> tensor<2xindex> + %18794 = stablehlo.dynamic_broadcast_in_dim %18792, %18793, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %18795 = 
stablehlo.dynamic_broadcast_in_dim %213, %18793, dims = [] : (tensor, tensor<2xindex>) -> tensor + %18796 = stablehlo.multiply %18794, %18795 : tensor + %dim_6606 = tensor.dim %18756, %c0 : tensor + %18797 = arith.index_cast %dim_6606 : index to i64 + %dim_6607 = tensor.dim %18792, %c0 : tensor + %18798 = arith.index_cast %dim_6607 : index to i64 + %18799 = arith.maxsi %18797, %18798 : i64 + %18800 = arith.index_cast %18799 : i64 to index + %from_elements_6608 = tensor.from_elements %18800, %c4096 : tensor<2xindex> + %18801 = stablehlo.dynamic_broadcast_in_dim %18756, %from_elements_6608, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_6609 = tensor.dim %18801, %c0 : tensor + %18802 = arith.index_cast %dim_6609 : index to i64 + %from_elements_6610 = tensor.from_elements %18802, %c4096_i64 : tensor<2xi64> + %18803 = stablehlo.real_dynamic_slice %18796, %c_22, %from_elements_6610, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_6611 = tensor.from_elements %18802, %c4096_i64, %c1_i64 : tensor<3xi64> + %18804 = stablehlo.dynamic_reshape %18801, %from_elements_6611 : (tensor, tensor<3xi64>) -> tensor + %18805 = stablehlo.dynamic_iota %from_elements_6611, dim = 1 : (tensor<3xi64>) -> tensor + %18806 = stablehlo.concatenate %18804, %18805, dim = 2 : (tensor, tensor) -> tensor + %18807 = "stablehlo.scatter"(%18744, %18806, %18803) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %18808 = stablehlo.slice %18553 [4:5, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %18809 = stablehlo.reshape %18808 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %18810 = stablehlo.custom_call @byteir.non_zero(%18809) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_6612 = tensor.dim %18810, %c0 : tensor + %18811 = arith.index_cast %dim_6612 : index to i64 + %from_elements_6613 = tensor.from_elements %18811, %c1_i64 : tensor<2xi64> + %18812 = stablehlo.real_dynamic_slice %18810, %c_22, %from_elements_6613, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6614 = tensor.dim %18812, %c0 : tensor + %18813 = arith.index_cast %dim_6614 : index to i64 + %from_elements_6615 = tensor.from_elements %18813 : tensor<1xi64> + %18814 = stablehlo.dynamic_reshape %18812, %from_elements_6615 : (tensor, tensor<1xi64>) -> tensor + %from_elements_6616 = tensor.from_elements %18811, %c2_i64 : tensor<2xi64> + %18815 = stablehlo.real_dynamic_slice %18810, %c_24, %from_elements_6616, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6617 = tensor.dim %18815, %c0 : tensor + %18816 = arith.index_cast %dim_6617 : index to i64 + %from_elements_6618 = tensor.from_elements %18816 : tensor<1xi64> + %18817 = stablehlo.dynamic_reshape %18815, %from_elements_6618 : (tensor, tensor<1xi64>) -> tensor + %dim_6619 = tensor.dim %18817, %c0 : tensor + %18818 = arith.index_cast %dim_6619 : index to i64 + %from_elements_6620 = tensor.from_elements %18818, %c1_i64 : tensor<2xi64> + %18819 = stablehlo.dynamic_reshape %18817, %from_elements_6620 : (tensor, tensor<2xi64>) -> tensor + %dim_6621 = tensor.dim %18819, %c0 : tensor + %18820 = arith.index_cast %dim_6621 : index to i64 + %from_elements_6622 = tensor.from_elements %c1_i64, %18820, %c4096_i64 : tensor<3xi64> + %18821 = 
stablehlo.dynamic_broadcast_in_dim %173, %from_elements_6622, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6623 = tensor.dim %18821, %c1 : tensor<1x?x4096xi64> + %18822 = arith.index_cast %dim_6623 : index to i64 + %from_elements_6624 = tensor.from_elements %c1_i64, %18822, %c4096_i64, %c1_i64 : tensor<4xi64> + %18823 = stablehlo.dynamic_reshape %18821, %from_elements_6624 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %18824 = stablehlo.dynamic_broadcast_in_dim %18819, %from_elements_6622, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6625 = tensor.dim %18824, %c1 : tensor<1x?x4096xi64> + %18825 = arith.index_cast %dim_6625 : index to i64 + %from_elements_6626 = tensor.from_elements %c1_i64, %18825, %c4096_i64, %c1_i64 : tensor<4xi64> + %18826 = stablehlo.dynamic_reshape %18824, %from_elements_6626 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %18827 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_6622, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6627 = tensor.dim %18827, %c1 : tensor<1x?x4096xi64> + %18828 = arith.index_cast %dim_6627 : index to i64 + %from_elements_6628 = tensor.from_elements %c1_i64, %18828, %c4096_i64, %c1_i64 : tensor<4xi64> + %18829 = stablehlo.dynamic_reshape %18827, %from_elements_6628 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %18830 = stablehlo.concatenate %18823, %18826, %18829, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %18831 = "stablehlo.gather"(%18564, %18830) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %18832 = shape.shape_of %18831 : tensor<1x?x4096xf32> -> tensor<3xindex> + %18833 = shape.num_elements %18832 : tensor<3xindex> -> index + %18834 = stablehlo.compute_reshape_shape %18833, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %18835 = stablehlo.dynamic_reshape %18831, %18834 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %18836 = stablehlo.dot %18835, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %18837 = stablehlo.logistic %18836 : tensor + %18838 = shape.shape_of %18837 : tensor -> tensor<2xindex> + %18839 = shape.shape_of %18836 : tensor -> tensor<2xindex> + %18840 = shape.cstr_broadcastable %18838, %18839 : tensor<2xindex>, tensor<2xindex> + %18841 = shape.assuming %18840 -> (tensor) { + %19688 = shape.broadcast %18838, %18839 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %18837, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %18836, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %18842 = shape.shape_of %18841 : tensor -> tensor<2xindex> + %18843 = shape.cstr_broadcastable %18842, %18839 : tensor<2xindex>, tensor<2xindex> + %18844 = shape.assuming %18843 -> (tensor) { + %19688 = shape.broadcast %18842, %18839 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %18841, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %18836, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + 
shape.assuming_yield %19691 : tensor + } + %18845 = stablehlo.dot %18844, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_6629 = tensor.dim %18817, %c0 : tensor + %18846 = arith.index_cast %dim_6629 : index to i64 + %from_elements_6630 = tensor.from_elements %18846, %c1_i64 : tensor<2xi64> + %18847 = stablehlo.dynamic_reshape %18817, %from_elements_6630 : (tensor, tensor<2xi64>) -> tensor + %dim_6631 = tensor.dim %18814, %c0 : tensor + %18848 = arith.index_cast %dim_6631 : index to i64 + %from_elements_6632 = tensor.from_elements %18848, %c1_i64 : tensor<2xi64> + %18849 = stablehlo.dynamic_reshape %18814, %from_elements_6632 : (tensor, tensor<2xi64>) -> tensor + %18850 = stablehlo.concatenate %18847, %18849, dim = 1 : (tensor, tensor) -> tensor + %18851 = "stablehlo.gather"(%18593, %18850) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %18852 = shape.shape_of %18845 : tensor -> tensor<2xindex> + %18853 = shape.shape_of %18851 : tensor -> tensor<2xindex> + %18854 = shape.cstr_broadcastable %18852, %18853 : tensor<2xindex>, tensor<2xindex> + %18855 = shape.assuming %18854 -> (tensor) { + %19688 = shape.broadcast %18852, %18853 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %18845, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %18851, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %18856 = shape.shape_of %18855 : tensor -> tensor<2xindex> + %18857 = stablehlo.dynamic_broadcast_in_dim %18855, %18856, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %18858 = stablehlo.dynamic_broadcast_in_dim %213, %18856, dims = [] : (tensor, tensor<2xindex>) -> tensor + %18859 = stablehlo.multiply %18857, %18858 : tensor + %dim_6633 = tensor.dim %18819, %c0 : tensor + %18860 = arith.index_cast %dim_6633 : index to i64 + %dim_6634 = tensor.dim %18855, %c0 : tensor + %18861 = arith.index_cast %dim_6634 : index to i64 + %18862 = arith.maxsi %18860, %18861 : i64 + %18863 = arith.index_cast %18862 : i64 to index + %from_elements_6635 = tensor.from_elements %18863, %c4096 : tensor<2xindex> + %18864 = stablehlo.dynamic_broadcast_in_dim %18819, %from_elements_6635, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_6636 = tensor.dim %18864, %c0 : tensor + %18865 = arith.index_cast %dim_6636 : index to i64 + %from_elements_6637 = tensor.from_elements %18865, %c4096_i64 : tensor<2xi64> + %18866 = stablehlo.real_dynamic_slice %18859, %c_22, %from_elements_6637, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_6638 = tensor.from_elements %18865, %c4096_i64, %c1_i64 : tensor<3xi64> + %18867 = stablehlo.dynamic_reshape %18864, %from_elements_6638 : (tensor, tensor<3xi64>) -> tensor + %18868 = stablehlo.dynamic_iota %from_elements_6638, dim = 1 : (tensor<3xi64>) -> tensor + %18869 = stablehlo.concatenate %18867, %18868, dim = 2 : (tensor, tensor) -> tensor + %18870 = "stablehlo.scatter"(%18807, %18869, %18866) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %18871 = stablehlo.slice %18553 [5:6, 0:2, 0:3] : 
(tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %18872 = stablehlo.reshape %18871 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %18873 = stablehlo.custom_call @byteir.non_zero(%18872) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_6639 = tensor.dim %18873, %c0 : tensor + %18874 = arith.index_cast %dim_6639 : index to i64 + %from_elements_6640 = tensor.from_elements %18874, %c1_i64 : tensor<2xi64> + %18875 = stablehlo.real_dynamic_slice %18873, %c_22, %from_elements_6640, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6641 = tensor.dim %18875, %c0 : tensor + %18876 = arith.index_cast %dim_6641 : index to i64 + %from_elements_6642 = tensor.from_elements %18876 : tensor<1xi64> + %18877 = stablehlo.dynamic_reshape %18875, %from_elements_6642 : (tensor, tensor<1xi64>) -> tensor + %from_elements_6643 = tensor.from_elements %18874, %c2_i64 : tensor<2xi64> + %18878 = stablehlo.real_dynamic_slice %18873, %c_24, %from_elements_6643, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6644 = tensor.dim %18878, %c0 : tensor + %18879 = arith.index_cast %dim_6644 : index to i64 + %from_elements_6645 = tensor.from_elements %18879 : tensor<1xi64> + %18880 = stablehlo.dynamic_reshape %18878, %from_elements_6645 : (tensor, tensor<1xi64>) -> tensor + %dim_6646 = tensor.dim %18880, %c0 : tensor + %18881 = arith.index_cast %dim_6646 : index to i64 + %from_elements_6647 = tensor.from_elements %18881, %c1_i64 : tensor<2xi64> + %18882 = stablehlo.dynamic_reshape %18880, %from_elements_6647 : (tensor, tensor<2xi64>) -> tensor + %dim_6648 = tensor.dim %18882, %c0 : tensor + %18883 = arith.index_cast %dim_6648 : index to i64 + %from_elements_6649 = tensor.from_elements %c1_i64, %18883, %c4096_i64 : tensor<3xi64> + %18884 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_6649, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6650 = tensor.dim %18884, %c1 : tensor<1x?x4096xi64> + %18885 = arith.index_cast %dim_6650 : index to i64 + %from_elements_6651 = tensor.from_elements %c1_i64, %18885, %c4096_i64, %c1_i64 : tensor<4xi64> + %18886 = stablehlo.dynamic_reshape %18884, %from_elements_6651 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %18887 = stablehlo.dynamic_broadcast_in_dim %18882, %from_elements_6649, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6652 = tensor.dim %18887, %c1 : tensor<1x?x4096xi64> + %18888 = arith.index_cast %dim_6652 : index to i64 + %from_elements_6653 = tensor.from_elements %c1_i64, %18888, %c4096_i64, %c1_i64 : tensor<4xi64> + %18889 = stablehlo.dynamic_reshape %18887, %from_elements_6653 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %18890 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_6649, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6654 = tensor.dim %18890, %c1 : tensor<1x?x4096xi64> + %18891 = arith.index_cast %dim_6654 : index to i64 + %from_elements_6655 = tensor.from_elements %c1_i64, %18891, %c4096_i64, %c1_i64 : tensor<4xi64> + %18892 = stablehlo.dynamic_reshape %18890, %from_elements_6655 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %18893 = stablehlo.concatenate %18886, %18889, %18892, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %18894 = "stablehlo.gather"(%18564, %18893) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : 
(tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %18895 = shape.shape_of %18894 : tensor<1x?x4096xf32> -> tensor<3xindex> + %18896 = shape.num_elements %18895 : tensor<3xindex> -> index + %18897 = stablehlo.compute_reshape_shape %18896, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %18898 = stablehlo.dynamic_reshape %18894, %18897 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %18899 = stablehlo.dot %18898, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %18900 = stablehlo.logistic %18899 : tensor + %18901 = shape.shape_of %18900 : tensor -> tensor<2xindex> + %18902 = shape.shape_of %18899 : tensor -> tensor<2xindex> + %18903 = shape.cstr_broadcastable %18901, %18902 : tensor<2xindex>, tensor<2xindex> + %18904 = shape.assuming %18903 -> (tensor) { + %19688 = shape.broadcast %18901, %18902 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %18900, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %18899, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %18905 = shape.shape_of %18904 : tensor -> tensor<2xindex> + %18906 = shape.cstr_broadcastable %18905, %18902 : tensor<2xindex>, tensor<2xindex> + %18907 = shape.assuming %18906 -> (tensor) { + %19688 = shape.broadcast %18905, %18902 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %18904, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %18899, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %18908 = stablehlo.dot %18907, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_6656 = tensor.dim %18880, %c0 : tensor + %18909 = arith.index_cast %dim_6656 : index to i64 + %from_elements_6657 = tensor.from_elements %18909, %c1_i64 : tensor<2xi64> + %18910 = stablehlo.dynamic_reshape %18880, %from_elements_6657 : (tensor, tensor<2xi64>) -> tensor + %dim_6658 = tensor.dim %18877, %c0 : tensor + %18911 = arith.index_cast %dim_6658 : index to i64 + %from_elements_6659 = tensor.from_elements %18911, %c1_i64 : tensor<2xi64> + %18912 = stablehlo.dynamic_reshape %18877, %from_elements_6659 : (tensor, tensor<2xi64>) -> tensor + %18913 = stablehlo.concatenate %18910, %18912, dim = 1 : (tensor, tensor) -> tensor + %18914 = "stablehlo.gather"(%18593, %18913) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %18915 = shape.shape_of %18908 : tensor -> tensor<2xindex> + %18916 = shape.shape_of %18914 : tensor -> tensor<2xindex> + %18917 = shape.cstr_broadcastable %18915, %18916 : tensor<2xindex>, tensor<2xindex> + %18918 = shape.assuming %18917 -> (tensor) { + %19688 = shape.broadcast %18915, %18916 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %18908, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %18914, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %18919 = shape.shape_of %18918 : tensor -> tensor<2xindex> + %18920 = stablehlo.dynamic_broadcast_in_dim %18918, %18919, dims = [0, 1] : (tensor, 
tensor<2xindex>) -> tensor + %18921 = stablehlo.dynamic_broadcast_in_dim %213, %18919, dims = [] : (tensor, tensor<2xindex>) -> tensor + %18922 = stablehlo.multiply %18920, %18921 : tensor + %dim_6660 = tensor.dim %18882, %c0 : tensor + %18923 = arith.index_cast %dim_6660 : index to i64 + %dim_6661 = tensor.dim %18918, %c0 : tensor + %18924 = arith.index_cast %dim_6661 : index to i64 + %18925 = arith.maxsi %18923, %18924 : i64 + %18926 = arith.index_cast %18925 : i64 to index + %from_elements_6662 = tensor.from_elements %18926, %c4096 : tensor<2xindex> + %18927 = stablehlo.dynamic_broadcast_in_dim %18882, %from_elements_6662, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_6663 = tensor.dim %18927, %c0 : tensor + %18928 = arith.index_cast %dim_6663 : index to i64 + %from_elements_6664 = tensor.from_elements %18928, %c4096_i64 : tensor<2xi64> + %18929 = stablehlo.real_dynamic_slice %18922, %c_22, %from_elements_6664, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_6665 = tensor.from_elements %18928, %c4096_i64, %c1_i64 : tensor<3xi64> + %18930 = stablehlo.dynamic_reshape %18927, %from_elements_6665 : (tensor, tensor<3xi64>) -> tensor + %18931 = stablehlo.dynamic_iota %from_elements_6665, dim = 1 : (tensor<3xi64>) -> tensor + %18932 = stablehlo.concatenate %18930, %18931, dim = 2 : (tensor, tensor) -> tensor + %18933 = "stablehlo.scatter"(%18870, %18932, %18929) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %18934 = stablehlo.slice %18553 [6:7, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %18935 = stablehlo.reshape %18934 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %18936 = stablehlo.custom_call @byteir.non_zero(%18935) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_6666 = tensor.dim %18936, %c0 : tensor + %18937 = arith.index_cast %dim_6666 : index to i64 + %from_elements_6667 = tensor.from_elements %18937, %c1_i64 : tensor<2xi64> + %18938 = stablehlo.real_dynamic_slice %18936, %c_22, %from_elements_6667, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6668 = tensor.dim %18938, %c0 : tensor + %18939 = arith.index_cast %dim_6668 : index to i64 + %from_elements_6669 = tensor.from_elements %18939 : tensor<1xi64> + %18940 = stablehlo.dynamic_reshape %18938, %from_elements_6669 : (tensor, tensor<1xi64>) -> tensor + %from_elements_6670 = tensor.from_elements %18937, %c2_i64 : tensor<2xi64> + %18941 = stablehlo.real_dynamic_slice %18936, %c_24, %from_elements_6670, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6671 = tensor.dim %18941, %c0 : tensor + %18942 = arith.index_cast %dim_6671 : index to i64 + %from_elements_6672 = tensor.from_elements %18942 : tensor<1xi64> + %18943 = stablehlo.dynamic_reshape %18941, %from_elements_6672 : (tensor, tensor<1xi64>) -> tensor + %dim_6673 = tensor.dim %18943, %c0 : tensor + %18944 = arith.index_cast %dim_6673 : index to i64 + %from_elements_6674 = tensor.from_elements %18944, %c1_i64 : tensor<2xi64> + %18945 = stablehlo.dynamic_reshape %18943, %from_elements_6674 : (tensor, tensor<2xi64>) -> tensor + %dim_6675 = tensor.dim %18945, %c0 : tensor + %18946 = arith.index_cast %dim_6675 : index to i64 + %from_elements_6676 = tensor.from_elements %c1_i64, %18946, %c4096_i64 : 
tensor<3xi64> + %18947 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_6676, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6677 = tensor.dim %18947, %c1 : tensor<1x?x4096xi64> + %18948 = arith.index_cast %dim_6677 : index to i64 + %from_elements_6678 = tensor.from_elements %c1_i64, %18948, %c4096_i64, %c1_i64 : tensor<4xi64> + %18949 = stablehlo.dynamic_reshape %18947, %from_elements_6678 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %18950 = stablehlo.dynamic_broadcast_in_dim %18945, %from_elements_6676, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6679 = tensor.dim %18950, %c1 : tensor<1x?x4096xi64> + %18951 = arith.index_cast %dim_6679 : index to i64 + %from_elements_6680 = tensor.from_elements %c1_i64, %18951, %c4096_i64, %c1_i64 : tensor<4xi64> + %18952 = stablehlo.dynamic_reshape %18950, %from_elements_6680 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %18953 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_6676, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6681 = tensor.dim %18953, %c1 : tensor<1x?x4096xi64> + %18954 = arith.index_cast %dim_6681 : index to i64 + %from_elements_6682 = tensor.from_elements %c1_i64, %18954, %c4096_i64, %c1_i64 : tensor<4xi64> + %18955 = stablehlo.dynamic_reshape %18953, %from_elements_6682 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %18956 = stablehlo.concatenate %18949, %18952, %18955, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %18957 = "stablehlo.gather"(%18564, %18956) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %18958 = shape.shape_of %18957 : tensor<1x?x4096xf32> -> tensor<3xindex> + %18959 = shape.num_elements %18958 : tensor<3xindex> -> index + %18960 = stablehlo.compute_reshape_shape %18959, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %18961 = stablehlo.dynamic_reshape %18957, %18960 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %18962 = stablehlo.dot %18961, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %18963 = stablehlo.logistic %18962 : tensor + %18964 = shape.shape_of %18963 : tensor -> tensor<2xindex> + %18965 = shape.shape_of %18962 : tensor -> tensor<2xindex> + %18966 = shape.cstr_broadcastable %18964, %18965 : tensor<2xindex>, tensor<2xindex> + %18967 = shape.assuming %18966 -> (tensor) { + %19688 = shape.broadcast %18964, %18965 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %18963, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %18962, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %18968 = shape.shape_of %18967 : tensor -> tensor<2xindex> + %18969 = shape.cstr_broadcastable %18968, %18965 : tensor<2xindex>, tensor<2xindex> + %18970 = shape.assuming %18969 -> (tensor) { + %19688 = shape.broadcast %18968, %18965 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %18967, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %18962, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, 
%19690 : tensor + shape.assuming_yield %19691 : tensor + } + %18971 = stablehlo.dot %18970, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_6683 = tensor.dim %18943, %c0 : tensor + %18972 = arith.index_cast %dim_6683 : index to i64 + %from_elements_6684 = tensor.from_elements %18972, %c1_i64 : tensor<2xi64> + %18973 = stablehlo.dynamic_reshape %18943, %from_elements_6684 : (tensor, tensor<2xi64>) -> tensor + %dim_6685 = tensor.dim %18940, %c0 : tensor + %18974 = arith.index_cast %dim_6685 : index to i64 + %from_elements_6686 = tensor.from_elements %18974, %c1_i64 : tensor<2xi64> + %18975 = stablehlo.dynamic_reshape %18940, %from_elements_6686 : (tensor, tensor<2xi64>) -> tensor + %18976 = stablehlo.concatenate %18973, %18975, dim = 1 : (tensor, tensor) -> tensor + %18977 = "stablehlo.gather"(%18593, %18976) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %18978 = shape.shape_of %18971 : tensor -> tensor<2xindex> + %18979 = shape.shape_of %18977 : tensor -> tensor<2xindex> + %18980 = shape.cstr_broadcastable %18978, %18979 : tensor<2xindex>, tensor<2xindex> + %18981 = shape.assuming %18980 -> (tensor) { + %19688 = shape.broadcast %18978, %18979 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %18971, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %18977, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %18982 = shape.shape_of %18981 : tensor -> tensor<2xindex> + %18983 = stablehlo.dynamic_broadcast_in_dim %18981, %18982, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %18984 = stablehlo.dynamic_broadcast_in_dim %213, %18982, dims = [] : (tensor, tensor<2xindex>) -> tensor + %18985 = stablehlo.multiply %18983, %18984 : tensor + %dim_6687 = tensor.dim %18945, %c0 : tensor + %18986 = arith.index_cast %dim_6687 : index to i64 + %dim_6688 = tensor.dim %18981, %c0 : tensor + %18987 = arith.index_cast %dim_6688 : index to i64 + %18988 = arith.maxsi %18986, %18987 : i64 + %18989 = arith.index_cast %18988 : i64 to index + %from_elements_6689 = tensor.from_elements %18989, %c4096 : tensor<2xindex> + %18990 = stablehlo.dynamic_broadcast_in_dim %18945, %from_elements_6689, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_6690 = tensor.dim %18990, %c0 : tensor + %18991 = arith.index_cast %dim_6690 : index to i64 + %from_elements_6691 = tensor.from_elements %18991, %c4096_i64 : tensor<2xi64> + %18992 = stablehlo.real_dynamic_slice %18985, %c_22, %from_elements_6691, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_6692 = tensor.from_elements %18991, %c4096_i64, %c1_i64 : tensor<3xi64> + %18993 = stablehlo.dynamic_reshape %18990, %from_elements_6692 : (tensor, tensor<3xi64>) -> tensor + %18994 = stablehlo.dynamic_iota %from_elements_6692, dim = 1 : (tensor<3xi64>) -> tensor + %18995 = stablehlo.concatenate %18993, %18994, dim = 2 : (tensor, tensor) -> tensor + %18996 = "stablehlo.scatter"(%18933, %18995, %18992) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %18997 = stablehlo.slice %18553 
[7:8, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %18998 = stablehlo.reshape %18997 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %18999 = stablehlo.custom_call @byteir.non_zero(%18998) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_6693 = tensor.dim %18999, %c0 : tensor + %19000 = arith.index_cast %dim_6693 : index to i64 + %from_elements_6694 = tensor.from_elements %19000, %c1_i64 : tensor<2xi64> + %19001 = stablehlo.real_dynamic_slice %18999, %c_22, %from_elements_6694, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6695 = tensor.dim %19001, %c0 : tensor + %19002 = arith.index_cast %dim_6695 : index to i64 + %from_elements_6696 = tensor.from_elements %19002 : tensor<1xi64> + %19003 = stablehlo.dynamic_reshape %19001, %from_elements_6696 : (tensor, tensor<1xi64>) -> tensor + %from_elements_6697 = tensor.from_elements %19000, %c2_i64 : tensor<2xi64> + %19004 = stablehlo.real_dynamic_slice %18999, %c_24, %from_elements_6697, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6698 = tensor.dim %19004, %c0 : tensor + %19005 = arith.index_cast %dim_6698 : index to i64 + %from_elements_6699 = tensor.from_elements %19005 : tensor<1xi64> + %19006 = stablehlo.dynamic_reshape %19004, %from_elements_6699 : (tensor, tensor<1xi64>) -> tensor + %dim_6700 = tensor.dim %19006, %c0 : tensor + %19007 = arith.index_cast %dim_6700 : index to i64 + %from_elements_6701 = tensor.from_elements %19007, %c1_i64 : tensor<2xi64> + %19008 = stablehlo.dynamic_reshape %19006, %from_elements_6701 : (tensor, tensor<2xi64>) -> tensor + %dim_6702 = tensor.dim %19008, %c0 : tensor + %19009 = arith.index_cast %dim_6702 : index to i64 + %from_elements_6703 = tensor.from_elements %c1_i64, %19009, %c4096_i64 : tensor<3xi64> + %19010 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_6703, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6704 = tensor.dim %19010, %c1 : tensor<1x?x4096xi64> + %19011 = arith.index_cast %dim_6704 : index to i64 + %from_elements_6705 = tensor.from_elements %c1_i64, %19011, %c4096_i64, %c1_i64 : tensor<4xi64> + %19012 = stablehlo.dynamic_reshape %19010, %from_elements_6705 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %19013 = stablehlo.dynamic_broadcast_in_dim %19008, %from_elements_6703, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6706 = tensor.dim %19013, %c1 : tensor<1x?x4096xi64> + %19014 = arith.index_cast %dim_6706 : index to i64 + %from_elements_6707 = tensor.from_elements %c1_i64, %19014, %c4096_i64, %c1_i64 : tensor<4xi64> + %19015 = stablehlo.dynamic_reshape %19013, %from_elements_6707 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %19016 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_6703, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6708 = tensor.dim %19016, %c1 : tensor<1x?x4096xi64> + %19017 = arith.index_cast %dim_6708 : index to i64 + %from_elements_6709 = tensor.from_elements %c1_i64, %19017, %c4096_i64, %c1_i64 : tensor<4xi64> + %19018 = stablehlo.dynamic_reshape %19016, %from_elements_6709 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %19019 = stablehlo.concatenate %19012, %19015, %19018, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %19020 = "stablehlo.gather"(%18564, %19019) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = 
array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %19021 = shape.shape_of %19020 : tensor<1x?x4096xf32> -> tensor<3xindex> + %19022 = shape.num_elements %19021 : tensor<3xindex> -> index + %19023 = stablehlo.compute_reshape_shape %19022, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %19024 = stablehlo.dynamic_reshape %19020, %19023 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %19025 = stablehlo.dot %19024, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %19026 = stablehlo.logistic %19025 : tensor + %19027 = shape.shape_of %19026 : tensor -> tensor<2xindex> + %19028 = shape.shape_of %19025 : tensor -> tensor<2xindex> + %19029 = shape.cstr_broadcastable %19027, %19028 : tensor<2xindex>, tensor<2xindex> + %19030 = shape.assuming %19029 -> (tensor) { + %19688 = shape.broadcast %19027, %19028 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %19026, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %19025, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %19031 = shape.shape_of %19030 : tensor -> tensor<2xindex> + %19032 = shape.cstr_broadcastable %19031, %19028 : tensor<2xindex>, tensor<2xindex> + %19033 = shape.assuming %19032 -> (tensor) { + %19688 = shape.broadcast %19031, %19028 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %19030, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %19025, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %19034 = stablehlo.dot %19033, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_6710 = tensor.dim %19006, %c0 : tensor + %19035 = arith.index_cast %dim_6710 : index to i64 + %from_elements_6711 = tensor.from_elements %19035, %c1_i64 : tensor<2xi64> + %19036 = stablehlo.dynamic_reshape %19006, %from_elements_6711 : (tensor, tensor<2xi64>) -> tensor + %dim_6712 = tensor.dim %19003, %c0 : tensor + %19037 = arith.index_cast %dim_6712 : index to i64 + %from_elements_6713 = tensor.from_elements %19037, %c1_i64 : tensor<2xi64> + %19038 = stablehlo.dynamic_reshape %19003, %from_elements_6713 : (tensor, tensor<2xi64>) -> tensor + %19039 = stablehlo.concatenate %19036, %19038, dim = 1 : (tensor, tensor) -> tensor + %19040 = "stablehlo.gather"(%18593, %19039) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %19041 = shape.shape_of %19034 : tensor -> tensor<2xindex> + %19042 = shape.shape_of %19040 : tensor -> tensor<2xindex> + %19043 = shape.cstr_broadcastable %19041, %19042 : tensor<2xindex>, tensor<2xindex> + %19044 = shape.assuming %19043 -> (tensor) { + %19688 = shape.broadcast %19041, %19042 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %19034, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %19040, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %19045 = shape.shape_of %19044 : tensor -> tensor<2xindex> + %19046 = stablehlo.dynamic_broadcast_in_dim %19044, %19045, dims = [0, 1] 
: (tensor, tensor<2xindex>) -> tensor + %19047 = stablehlo.dynamic_broadcast_in_dim %213, %19045, dims = [] : (tensor, tensor<2xindex>) -> tensor + %19048 = stablehlo.multiply %19046, %19047 : tensor + %dim_6714 = tensor.dim %19008, %c0 : tensor + %19049 = arith.index_cast %dim_6714 : index to i64 + %dim_6715 = tensor.dim %19044, %c0 : tensor + %19050 = arith.index_cast %dim_6715 : index to i64 + %19051 = arith.maxsi %19049, %19050 : i64 + %19052 = arith.index_cast %19051 : i64 to index + %from_elements_6716 = tensor.from_elements %19052, %c4096 : tensor<2xindex> + %19053 = stablehlo.dynamic_broadcast_in_dim %19008, %from_elements_6716, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_6717 = tensor.dim %19053, %c0 : tensor + %19054 = arith.index_cast %dim_6717 : index to i64 + %from_elements_6718 = tensor.from_elements %19054, %c4096_i64 : tensor<2xi64> + %19055 = stablehlo.real_dynamic_slice %19048, %c_22, %from_elements_6718, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_6719 = tensor.from_elements %19054, %c4096_i64, %c1_i64 : tensor<3xi64> + %19056 = stablehlo.dynamic_reshape %19053, %from_elements_6719 : (tensor, tensor<3xi64>) -> tensor + %19057 = stablehlo.dynamic_iota %from_elements_6719, dim = 1 : (tensor<3xi64>) -> tensor + %19058 = stablehlo.concatenate %19056, %19057, dim = 2 : (tensor, tensor) -> tensor + %19059 = "stablehlo.scatter"(%18996, %19058, %19055) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %19060 = stablehlo.reshape %19059 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %19061 = stablehlo.add %18526, %19060 : tensor<3x1x4096xf32> + %19062 = stablehlo.broadcast_in_dim %19061, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %19063 = stablehlo.power %19062, %15 : tensor<3x1x4096xf32> + %19064 = stablehlo.reduce(%19063 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %19065 = stablehlo.reshape %19064 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %19066 = stablehlo.broadcast_in_dim %19065, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %19067 = stablehlo.divide %19066, %21 : tensor<3x1x1xf32> + %19068 = stablehlo.broadcast_in_dim %19067, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %19069 = stablehlo.add %19068, %25 : tensor<3x1x1xf32> + %19070 = stablehlo.rsqrt %19069 : tensor<3x1x1xf32> + %19071 = stablehlo.broadcast_in_dim %19070, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %19072 = stablehlo.multiply %19062, %19071 : tensor<3x1x4096xf32> + %19073 = stablehlo.broadcast_in_dim %19072, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %19074 = stablehlo.multiply %19073, %31 : tensor<3x1x4096xf32> + %19075 = stablehlo.reshape %19074 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %19076 = stablehlo.dot %19075, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %19077 = stablehlo.reshape %19076 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %19078 = stablehlo.dot %19075, %37 : (tensor<3x4096xf32>, tensor<4096x1024xf32>) -> tensor<3x1024xf32> + %19079 = stablehlo.reshape %19078 : (tensor<3x1024xf32>) -> tensor<3x1x1024xf32> + %19080 = stablehlo.reshape %19077 : (tensor<3x1x4096xf32>) -> 
tensor<3x1x32x128xf32> + %19081 = stablehlo.transpose %19080, dims = [0, 2, 1, 3] : (tensor<3x1x32x128xf32>) -> tensor<3x32x1x128xf32> + %19082 = stablehlo.reshape %19079 : (tensor<3x1x1024xf32>) -> tensor<3x1x8x128xf32> + %19083 = stablehlo.transpose %19082, dims = [0, 2, 1, 3] : (tensor<3x1x8x128xf32>) -> tensor<3x8x1x128xf32> + %19084 = stablehlo.slice %arg62 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %19085 = stablehlo.slice %arg63 [0:8, 0:128] : (tensor<131072x128xf32>) -> tensor<8x128xf32> + %19086 = "stablehlo.gather"(%19084, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %19087 = stablehlo.reshape %19086 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %19088 = "stablehlo.gather"(%19085, %46) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<8x128xf32>, tensor<1x1x1xi64>) -> tensor<1x1x128xf32> + %19089 = stablehlo.reshape %19088 : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32> + %19090 = stablehlo.broadcast_in_dim %19081, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %19091 = stablehlo.broadcast_in_dim %19087, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %19092 = stablehlo.multiply %19090, %19091 : tensor<3x32x1x128xf32> + %19093 = stablehlo.slice %19081 [0:3, 0:32, 0:1, 0:64] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %19094 = stablehlo.slice %19081 [0:3, 0:32, 0:1, 64:128] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x64xf32> + %19095 = stablehlo.negate %19094 : tensor<3x32x1x64xf32> + %19096 = stablehlo.concatenate %19095, %19093, dim = 3 : (tensor<3x32x1x64xf32>, tensor<3x32x1x64xf32>) -> tensor<3x32x1x128xf32> + %19097 = stablehlo.broadcast_in_dim %19096, dims = [0, 1, 2, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x32x1x128xf32> + %19098 = stablehlo.broadcast_in_dim %19089, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x32x1x128xf32> + %19099 = stablehlo.multiply %19097, %19098 : tensor<3x32x1x128xf32> + %19100 = stablehlo.add %19092, %19099 : tensor<3x32x1x128xf32> + %19101 = stablehlo.broadcast_in_dim %19083, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %19102 = stablehlo.broadcast_in_dim %19087, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %19103 = stablehlo.multiply %19101, %19102 : tensor<3x8x1x128xf32> + %19104 = stablehlo.slice %19083 [0:3, 0:8, 0:1, 0:64] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %19105 = stablehlo.slice %19083 [0:3, 0:8, 0:1, 64:128] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x64xf32> + %19106 = stablehlo.negate %19105 : tensor<3x8x1x64xf32> + %19107 = stablehlo.concatenate %19106, %19104, dim = 3 : (tensor<3x8x1x64xf32>, tensor<3x8x1x64xf32>) -> tensor<3x8x1x128xf32> + %19108 = stablehlo.broadcast_in_dim %19107, dims = [0, 1, 2, 3] : (tensor<3x8x1x128xf32>) -> tensor<3x8x1x128xf32> + %19109 = stablehlo.broadcast_in_dim %19089, dims = [0, 1, 2, 3] : (tensor<1x1x1x128xf32>) -> tensor<3x8x1x128xf32> + %19110 = stablehlo.multiply %19108, %19109 : tensor<3x8x1x128xf32> + %19111 = stablehlo.add %19103, %19110 : tensor<3x8x1x128xf32> + %19112 = stablehlo.concatenate %arg127, %19111, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %19113 = stablehlo.concatenate %arg128, %19083, dim = 2 : (tensor<3x8x7x128xf32>, tensor<3x8x1x128xf32>) -> tensor<3x8x8x128xf32> + %19114 = stablehlo.reshape %19112 : 
(tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %19115 = stablehlo.broadcast_in_dim %19114, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %19116 = stablehlo.reshape %19115 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %19117 = stablehlo.reshape %19113 : (tensor<3x8x8x128xf32>) -> tensor<3x8x1x8x128xf32> + %19118 = stablehlo.broadcast_in_dim %19117, dims = [0, 1, 2, 3, 4] : (tensor<3x8x1x8x128xf32>) -> tensor<3x8x4x8x128xf32> + %19119 = stablehlo.reshape %19118 : (tensor<3x8x4x8x128xf32>) -> tensor<3x32x8x128xf32> + %19120 = stablehlo.transpose %19116, dims = [0, 1, 3, 2] : (tensor<3x32x8x128xf32>) -> tensor<3x32x128x8xf32> + %19121 = stablehlo.reshape %19100 : (tensor<3x32x1x128xf32>) -> tensor<96x1x128xf32> + %19122 = stablehlo.reshape %19120 : (tensor<3x32x128x8xf32>) -> tensor<96x128x8xf32> + %19123 = stablehlo.broadcast_in_dim %19122, dims = [0, 1, 2] : (tensor<96x128x8xf32>) -> tensor<96x128x8xf32> + %19124 = stablehlo.dot_general %19121, %19123, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x128xf32>, tensor<96x128x8xf32>) -> tensor<96x1x8xf32> + %19125 = stablehlo.reshape %19124 : (tensor<96x1x8xf32>) -> tensor<3x32x1x8xf32> + %19126 = stablehlo.broadcast_in_dim %19125, dims = [0, 1, 2, 3] : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %19127 = stablehlo.divide %19126, %89 : tensor<3x32x1x8xf32> + %19128 = stablehlo.custom_call @byteir.softmax(%19127) {byteir_attrs = {axis = 3 : i64}} : (tensor<3x32x1x8xf32>) -> tensor<3x32x1x8xf32> + %19129 = stablehlo.reshape %19128 : (tensor<3x32x1x8xf32>) -> tensor<96x1x8xf32> + %19130 = stablehlo.reshape %19119 : (tensor<3x32x8x128xf32>) -> tensor<96x8x128xf32> + %19131 = stablehlo.broadcast_in_dim %19130, dims = [0, 1, 2] : (tensor<96x8x128xf32>) -> tensor<96x8x128xf32> + %19132 = stablehlo.dot_general %19129, %19131, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<96x1x8xf32>, tensor<96x8x128xf32>) -> tensor<96x1x128xf32> + %19133 = stablehlo.reshape %19132 : (tensor<96x1x128xf32>) -> tensor<3x32x1x128xf32> + %19134 = stablehlo.transpose %19133, dims = [0, 2, 1, 3] : (tensor<3x32x1x128xf32>) -> tensor<3x1x32x128xf32> + %19135 = stablehlo.reshape %19134 : (tensor<3x1x32x128xf32>) -> tensor<3x1x4096xf32> + %19136 = stablehlo.reshape %19135 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %19137 = stablehlo.dot %19136, %33 : (tensor<3x4096xf32>, tensor<4096x4096xf32>) -> tensor<3x4096xf32> + %19138 = stablehlo.reshape %19137 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %19139 = stablehlo.add %19061, %19138 : tensor<3x1x4096xf32> + %19140 = stablehlo.broadcast_in_dim %19139, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %19141 = stablehlo.power %19140, %15 : tensor<3x1x4096xf32> + %19142 = stablehlo.reduce(%19141 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %19143 = stablehlo.reshape %19142 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %19144 = stablehlo.broadcast_in_dim %19143, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %19145 = stablehlo.divide %19144, %21 : tensor<3x1x1xf32> + %19146 = stablehlo.broadcast_in_dim %19145, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %19147 = stablehlo.add %19146, %25 : tensor<3x1x1xf32> + %19148 = stablehlo.rsqrt %19147 : tensor<3x1x1xf32> + %19149 = stablehlo.broadcast_in_dim %19148, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %19150 = stablehlo.multiply %19140, %19149 : 
tensor<3x1x4096xf32> + %19151 = stablehlo.broadcast_in_dim %19150, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %19152 = stablehlo.multiply %19151, %31 : tensor<3x1x4096xf32> + %19153 = stablehlo.reshape %19152 : (tensor<3x1x4096xf32>) -> tensor<3x4096xf32> + %19154 = stablehlo.dot %19153, %117 : (tensor<3x4096xf32>, tensor<4096x8xf32>) -> tensor<3x8xf32> + %19155 = stablehlo.custom_call @byteir.softmax(%19154) {byteir_attrs = {axis = 1 : i64}} : (tensor<3x8xf32>) -> tensor<3x8xf32> + %19156:2 = stablehlo.custom_call @byteir.top_k(%19155) {byteir_attrs = {axis = [1], k = 2 : i64, sorted = true}} : (tensor<3x8xf32>) -> (tensor<3x2xf32>, tensor<3x2xi64>) + %19157 = stablehlo.reduce(%19156#0 init: %cst_20) applies stablehlo.add across dimensions = [1] : (tensor<3x2xf32>, tensor) -> tensor<3xf32> + %19158 = stablehlo.reshape %19157 : (tensor<3xf32>) -> tensor<3x1xf32> + %19159 = stablehlo.broadcast_in_dim %19156#0, dims = [0, 1] : (tensor<3x2xf32>) -> tensor<3x2xf32> + %19160 = stablehlo.broadcast_in_dim %19158, dims = [0, 1] : (tensor<3x1xf32>) -> tensor<3x2xf32> + %19161 = stablehlo.divide %19159, %19160 : tensor<3x2xf32> + %19162 = stablehlo.reshape %19156#1 : (tensor<3x2xi64>) -> tensor<3x2x1xi64> + %19163 = stablehlo.broadcast_in_dim %19162, dims = [0, 1, 2] : (tensor<3x2x1xi64>) -> tensor<3x2x8xi64> + %19164 = stablehlo.compare EQ, %19163, %137, SIGNED : (tensor<3x2x8xi64>, tensor<3x2x8xi64>) -> tensor<3x2x8xi1> + %19165 = stablehlo.convert %19164 : (tensor<3x2x8xi1>) -> tensor<3x2x8xi64> + %19166 = stablehlo.transpose %19165, dims = [2, 1, 0] : (tensor<3x2x8xi64>) -> tensor<8x2x3xi64> + %19167 = stablehlo.slice %19166 [0:1, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %19168 = stablehlo.reshape %19167 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %19169 = stablehlo.custom_call @byteir.non_zero(%19168) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_6720 = tensor.dim %19169, %c0 : tensor + %19170 = arith.index_cast %dim_6720 : index to i64 + %from_elements_6721 = tensor.from_elements %19170, %c1_i64 : tensor<2xi64> + %19171 = stablehlo.real_dynamic_slice %19169, %c_22, %from_elements_6721, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6722 = tensor.dim %19171, %c0 : tensor + %19172 = arith.index_cast %dim_6722 : index to i64 + %from_elements_6723 = tensor.from_elements %19172 : tensor<1xi64> + %19173 = stablehlo.dynamic_reshape %19171, %from_elements_6723 : (tensor, tensor<1xi64>) -> tensor + %from_elements_6724 = tensor.from_elements %19170, %c2_i64 : tensor<2xi64> + %19174 = stablehlo.real_dynamic_slice %19169, %c_24, %from_elements_6724, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6725 = tensor.dim %19174, %c0 : tensor + %19175 = arith.index_cast %dim_6725 : index to i64 + %from_elements_6726 = tensor.from_elements %19175 : tensor<1xi64> + %19176 = stablehlo.dynamic_reshape %19174, %from_elements_6726 : (tensor, tensor<1xi64>) -> tensor + %19177 = stablehlo.reshape %19153 : (tensor<3x4096xf32>) -> tensor<1x3x4096xf32> + %dim_6727 = tensor.dim %19176, %c0 : tensor + %19178 = arith.index_cast %dim_6727 : index to i64 + %from_elements_6728 = tensor.from_elements %19178, %c1_i64 : tensor<2xi64> + %19179 = stablehlo.dynamic_reshape %19176, %from_elements_6728 : (tensor, tensor<2xi64>) -> tensor + %dim_6729 = tensor.dim %19179, %c0 : tensor + %19180 = arith.index_cast %dim_6729 : index to i64 + %from_elements_6730 = tensor.from_elements %c1_i64, %19180, %c4096_i64 : tensor<3xi64> + 
%19181 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_6730, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6731 = tensor.dim %19181, %c1 : tensor<1x?x4096xi64> + %19182 = arith.index_cast %dim_6731 : index to i64 + %from_elements_6732 = tensor.from_elements %c1_i64, %19182, %c4096_i64, %c1_i64 : tensor<4xi64> + %19183 = stablehlo.dynamic_reshape %19181, %from_elements_6732 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %19184 = stablehlo.dynamic_broadcast_in_dim %19179, %from_elements_6730, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6733 = tensor.dim %19184, %c1 : tensor<1x?x4096xi64> + %19185 = arith.index_cast %dim_6733 : index to i64 + %from_elements_6734 = tensor.from_elements %c1_i64, %19185, %c4096_i64, %c1_i64 : tensor<4xi64> + %19186 = stablehlo.dynamic_reshape %19184, %from_elements_6734 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %19187 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_6730, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6735 = tensor.dim %19187, %c1 : tensor<1x?x4096xi64> + %19188 = arith.index_cast %dim_6735 : index to i64 + %from_elements_6736 = tensor.from_elements %c1_i64, %19188, %c4096_i64, %c1_i64 : tensor<4xi64> + %19189 = stablehlo.dynamic_reshape %19187, %from_elements_6736 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %19190 = stablehlo.concatenate %19183, %19186, %19189, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %19191 = "stablehlo.gather"(%19177, %19190) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %19192 = shape.shape_of %19191 : tensor<1x?x4096xf32> -> tensor<3xindex> + %19193 = shape.num_elements %19192 : tensor<3xindex> -> index + %19194 = stablehlo.compute_reshape_shape %19193, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %19195 = stablehlo.dynamic_reshape %19191, %19194 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %19196 = stablehlo.dot %19195, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %19197 = stablehlo.logistic %19196 : tensor + %19198 = shape.shape_of %19197 : tensor -> tensor<2xindex> + %19199 = shape.shape_of %19196 : tensor -> tensor<2xindex> + %19200 = shape.cstr_broadcastable %19198, %19199 : tensor<2xindex>, tensor<2xindex> + %19201 = shape.assuming %19200 -> (tensor) { + %19688 = shape.broadcast %19198, %19199 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %19197, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %19196, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %19202 = shape.shape_of %19201 : tensor -> tensor<2xindex> + %19203 = shape.cstr_broadcastable %19202, %19199 : tensor<2xindex>, tensor<2xindex> + %19204 = shape.assuming %19203 -> (tensor) { + %19688 = shape.broadcast %19202, %19199 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %19201, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %19196, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + 
shape.assuming_yield %19691 : tensor + } + %19205 = stablehlo.dot %19204, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %19206 = stablehlo.reshape %19161 : (tensor<3x2xf32>) -> tensor<3x2x1xf32> + %dim_6737 = tensor.dim %19176, %c0 : tensor + %19207 = arith.index_cast %dim_6737 : index to i64 + %from_elements_6738 = tensor.from_elements %19207, %c1_i64 : tensor<2xi64> + %19208 = stablehlo.dynamic_reshape %19176, %from_elements_6738 : (tensor, tensor<2xi64>) -> tensor + %dim_6739 = tensor.dim %19173, %c0 : tensor + %19209 = arith.index_cast %dim_6739 : index to i64 + %from_elements_6740 = tensor.from_elements %19209, %c1_i64 : tensor<2xi64> + %19210 = stablehlo.dynamic_reshape %19173, %from_elements_6740 : (tensor, tensor<2xi64>) -> tensor + %19211 = stablehlo.concatenate %19208, %19210, dim = 1 : (tensor, tensor) -> tensor + %19212 = "stablehlo.gather"(%19206, %19211) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %19213 = shape.shape_of %19205 : tensor -> tensor<2xindex> + %19214 = shape.shape_of %19212 : tensor -> tensor<2xindex> + %19215 = shape.cstr_broadcastable %19213, %19214 : tensor<2xindex>, tensor<2xindex> + %19216 = shape.assuming %19215 -> (tensor) { + %19688 = shape.broadcast %19213, %19214 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %19205, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %19212, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %19217 = shape.shape_of %19216 : tensor -> tensor<2xindex> + %19218 = stablehlo.dynamic_broadcast_in_dim %19216, %19217, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19219 = stablehlo.dynamic_broadcast_in_dim %213, %19217, dims = [] : (tensor, tensor<2xindex>) -> tensor + %19220 = stablehlo.multiply %19218, %19219 : tensor + %dim_6741 = tensor.dim %19179, %c0 : tensor + %19221 = arith.index_cast %dim_6741 : index to i64 + %dim_6742 = tensor.dim %19216, %c0 : tensor + %19222 = arith.index_cast %dim_6742 : index to i64 + %19223 = arith.maxsi %19221, %19222 : i64 + %19224 = arith.index_cast %19223 : i64 to index + %from_elements_6743 = tensor.from_elements %19224, %c4096 : tensor<2xindex> + %19225 = stablehlo.dynamic_broadcast_in_dim %19179, %from_elements_6743, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_6744 = tensor.dim %19225, %c0 : tensor + %19226 = arith.index_cast %dim_6744 : index to i64 + %from_elements_6745 = tensor.from_elements %19226, %c4096_i64 : tensor<2xi64> + %19227 = stablehlo.real_dynamic_slice %19220, %c_22, %from_elements_6745, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_6746 = tensor.from_elements %19226, %c4096_i64, %c1_i64 : tensor<3xi64> + %19228 = stablehlo.dynamic_reshape %19225, %from_elements_6746 : (tensor, tensor<3xi64>) -> tensor + %19229 = stablehlo.dynamic_iota %from_elements_6746, dim = 1 : (tensor<3xi64>) -> tensor + %19230 = stablehlo.concatenate %19228, %19229, dim = 2 : (tensor, tensor) -> tensor + %19231 = "stablehlo.scatter"(%cst_2, %19230, %19227) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) 
-> tensor<3x4096xf32> + %19232 = stablehlo.slice %19166 [1:2, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %19233 = stablehlo.reshape %19232 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %19234 = stablehlo.custom_call @byteir.non_zero(%19233) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_6747 = tensor.dim %19234, %c0 : tensor + %19235 = arith.index_cast %dim_6747 : index to i64 + %from_elements_6748 = tensor.from_elements %19235, %c1_i64 : tensor<2xi64> + %19236 = stablehlo.real_dynamic_slice %19234, %c_22, %from_elements_6748, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6749 = tensor.dim %19236, %c0 : tensor + %19237 = arith.index_cast %dim_6749 : index to i64 + %from_elements_6750 = tensor.from_elements %19237 : tensor<1xi64> + %19238 = stablehlo.dynamic_reshape %19236, %from_elements_6750 : (tensor, tensor<1xi64>) -> tensor + %from_elements_6751 = tensor.from_elements %19235, %c2_i64 : tensor<2xi64> + %19239 = stablehlo.real_dynamic_slice %19234, %c_24, %from_elements_6751, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6752 = tensor.dim %19239, %c0 : tensor + %19240 = arith.index_cast %dim_6752 : index to i64 + %from_elements_6753 = tensor.from_elements %19240 : tensor<1xi64> + %19241 = stablehlo.dynamic_reshape %19239, %from_elements_6753 : (tensor, tensor<1xi64>) -> tensor + %dim_6754 = tensor.dim %19241, %c0 : tensor + %19242 = arith.index_cast %dim_6754 : index to i64 + %from_elements_6755 = tensor.from_elements %19242, %c1_i64 : tensor<2xi64> + %19243 = stablehlo.dynamic_reshape %19241, %from_elements_6755 : (tensor, tensor<2xi64>) -> tensor + %dim_6756 = tensor.dim %19243, %c0 : tensor + %19244 = arith.index_cast %dim_6756 : index to i64 + %from_elements_6757 = tensor.from_elements %c1_i64, %19244, %c4096_i64 : tensor<3xi64> + %19245 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_6757, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6758 = tensor.dim %19245, %c1 : tensor<1x?x4096xi64> + %19246 = arith.index_cast %dim_6758 : index to i64 + %from_elements_6759 = tensor.from_elements %c1_i64, %19246, %c4096_i64, %c1_i64 : tensor<4xi64> + %19247 = stablehlo.dynamic_reshape %19245, %from_elements_6759 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %19248 = stablehlo.dynamic_broadcast_in_dim %19243, %from_elements_6757, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6760 = tensor.dim %19248, %c1 : tensor<1x?x4096xi64> + %19249 = arith.index_cast %dim_6760 : index to i64 + %from_elements_6761 = tensor.from_elements %c1_i64, %19249, %c4096_i64, %c1_i64 : tensor<4xi64> + %19250 = stablehlo.dynamic_reshape %19248, %from_elements_6761 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %19251 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_6757, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6762 = tensor.dim %19251, %c1 : tensor<1x?x4096xi64> + %19252 = arith.index_cast %dim_6762 : index to i64 + %from_elements_6763 = tensor.from_elements %c1_i64, %19252, %c4096_i64, %c1_i64 : tensor<4xi64> + %19253 = stablehlo.dynamic_reshape %19251, %from_elements_6763 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %19254 = stablehlo.concatenate %19247, %19250, %19253, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %19255 = "stablehlo.gather"(%19177, %19254) <{dimension_numbers = 
#stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %19256 = shape.shape_of %19255 : tensor<1x?x4096xf32> -> tensor<3xindex> + %19257 = shape.num_elements %19256 : tensor<3xindex> -> index + %19258 = stablehlo.compute_reshape_shape %19257, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %19259 = stablehlo.dynamic_reshape %19255, %19258 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %19260 = stablehlo.dot %19259, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %19261 = stablehlo.logistic %19260 : tensor + %19262 = shape.shape_of %19261 : tensor -> tensor<2xindex> + %19263 = shape.shape_of %19260 : tensor -> tensor<2xindex> + %19264 = shape.cstr_broadcastable %19262, %19263 : tensor<2xindex>, tensor<2xindex> + %19265 = shape.assuming %19264 -> (tensor) { + %19688 = shape.broadcast %19262, %19263 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %19261, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %19260, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %19266 = shape.shape_of %19265 : tensor -> tensor<2xindex> + %19267 = shape.cstr_broadcastable %19266, %19263 : tensor<2xindex>, tensor<2xindex> + %19268 = shape.assuming %19267 -> (tensor) { + %19688 = shape.broadcast %19266, %19263 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %19265, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %19260, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %19269 = stablehlo.dot %19268, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_6764 = tensor.dim %19241, %c0 : tensor + %19270 = arith.index_cast %dim_6764 : index to i64 + %from_elements_6765 = tensor.from_elements %19270, %c1_i64 : tensor<2xi64> + %19271 = stablehlo.dynamic_reshape %19241, %from_elements_6765 : (tensor, tensor<2xi64>) -> tensor + %dim_6766 = tensor.dim %19238, %c0 : tensor + %19272 = arith.index_cast %dim_6766 : index to i64 + %from_elements_6767 = tensor.from_elements %19272, %c1_i64 : tensor<2xi64> + %19273 = stablehlo.dynamic_reshape %19238, %from_elements_6767 : (tensor, tensor<2xi64>) -> tensor + %19274 = stablehlo.concatenate %19271, %19273, dim = 1 : (tensor, tensor) -> tensor + %19275 = "stablehlo.gather"(%19206, %19274) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %19276 = shape.shape_of %19269 : tensor -> tensor<2xindex> + %19277 = shape.shape_of %19275 : tensor -> tensor<2xindex> + %19278 = shape.cstr_broadcastable %19276, %19277 : tensor<2xindex>, tensor<2xindex> + %19279 = shape.assuming %19278 -> (tensor) { + %19688 = shape.broadcast %19276, %19277 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %19269, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %19275, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %19280 = shape.shape_of %19279 : tensor -> tensor<2xindex> + %19281 = 
stablehlo.dynamic_broadcast_in_dim %19279, %19280, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19282 = stablehlo.dynamic_broadcast_in_dim %213, %19280, dims = [] : (tensor, tensor<2xindex>) -> tensor + %19283 = stablehlo.multiply %19281, %19282 : tensor + %dim_6768 = tensor.dim %19243, %c0 : tensor + %19284 = arith.index_cast %dim_6768 : index to i64 + %dim_6769 = tensor.dim %19279, %c0 : tensor + %19285 = arith.index_cast %dim_6769 : index to i64 + %19286 = arith.maxsi %19284, %19285 : i64 + %19287 = arith.index_cast %19286 : i64 to index + %from_elements_6770 = tensor.from_elements %19287, %c4096 : tensor<2xindex> + %19288 = stablehlo.dynamic_broadcast_in_dim %19243, %from_elements_6770, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_6771 = tensor.dim %19288, %c0 : tensor + %19289 = arith.index_cast %dim_6771 : index to i64 + %from_elements_6772 = tensor.from_elements %19289, %c4096_i64 : tensor<2xi64> + %19290 = stablehlo.real_dynamic_slice %19283, %c_22, %from_elements_6772, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_6773 = tensor.from_elements %19289, %c4096_i64, %c1_i64 : tensor<3xi64> + %19291 = stablehlo.dynamic_reshape %19288, %from_elements_6773 : (tensor, tensor<3xi64>) -> tensor + %19292 = stablehlo.dynamic_iota %from_elements_6773, dim = 1 : (tensor<3xi64>) -> tensor + %19293 = stablehlo.concatenate %19291, %19292, dim = 2 : (tensor, tensor) -> tensor + %19294 = "stablehlo.scatter"(%19231, %19293, %19290) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %19295 = stablehlo.slice %19166 [2:3, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %19296 = stablehlo.reshape %19295 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %19297 = stablehlo.custom_call @byteir.non_zero(%19296) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_6774 = tensor.dim %19297, %c0 : tensor + %19298 = arith.index_cast %dim_6774 : index to i64 + %from_elements_6775 = tensor.from_elements %19298, %c1_i64 : tensor<2xi64> + %19299 = stablehlo.real_dynamic_slice %19297, %c_22, %from_elements_6775, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6776 = tensor.dim %19299, %c0 : tensor + %19300 = arith.index_cast %dim_6776 : index to i64 + %from_elements_6777 = tensor.from_elements %19300 : tensor<1xi64> + %19301 = stablehlo.dynamic_reshape %19299, %from_elements_6777 : (tensor, tensor<1xi64>) -> tensor + %from_elements_6778 = tensor.from_elements %19298, %c2_i64 : tensor<2xi64> + %19302 = stablehlo.real_dynamic_slice %19297, %c_24, %from_elements_6778, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6779 = tensor.dim %19302, %c0 : tensor + %19303 = arith.index_cast %dim_6779 : index to i64 + %from_elements_6780 = tensor.from_elements %19303 : tensor<1xi64> + %19304 = stablehlo.dynamic_reshape %19302, %from_elements_6780 : (tensor, tensor<1xi64>) -> tensor + %dim_6781 = tensor.dim %19304, %c0 : tensor + %19305 = arith.index_cast %dim_6781 : index to i64 + %from_elements_6782 = tensor.from_elements %19305, %c1_i64 : tensor<2xi64> + %19306 = stablehlo.dynamic_reshape %19304, %from_elements_6782 : (tensor, tensor<2xi64>) -> tensor + %dim_6783 = tensor.dim %19306, %c0 : tensor + %19307 = arith.index_cast %dim_6783 : index to i64 + 
%from_elements_6784 = tensor.from_elements %c1_i64, %19307, %c4096_i64 : tensor<3xi64> + %19308 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_6784, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6785 = tensor.dim %19308, %c1 : tensor<1x?x4096xi64> + %19309 = arith.index_cast %dim_6785 : index to i64 + %from_elements_6786 = tensor.from_elements %c1_i64, %19309, %c4096_i64, %c1_i64 : tensor<4xi64> + %19310 = stablehlo.dynamic_reshape %19308, %from_elements_6786 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %19311 = stablehlo.dynamic_broadcast_in_dim %19306, %from_elements_6784, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6787 = tensor.dim %19311, %c1 : tensor<1x?x4096xi64> + %19312 = arith.index_cast %dim_6787 : index to i64 + %from_elements_6788 = tensor.from_elements %c1_i64, %19312, %c4096_i64, %c1_i64 : tensor<4xi64> + %19313 = stablehlo.dynamic_reshape %19311, %from_elements_6788 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %19314 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_6784, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6789 = tensor.dim %19314, %c1 : tensor<1x?x4096xi64> + %19315 = arith.index_cast %dim_6789 : index to i64 + %from_elements_6790 = tensor.from_elements %c1_i64, %19315, %c4096_i64, %c1_i64 : tensor<4xi64> + %19316 = stablehlo.dynamic_reshape %19314, %from_elements_6790 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %19317 = stablehlo.concatenate %19310, %19313, %19316, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %19318 = "stablehlo.gather"(%19177, %19317) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %19319 = shape.shape_of %19318 : tensor<1x?x4096xf32> -> tensor<3xindex> + %19320 = shape.num_elements %19319 : tensor<3xindex> -> index + %19321 = stablehlo.compute_reshape_shape %19320, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %19322 = stablehlo.dynamic_reshape %19318, %19321 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %19323 = stablehlo.dot %19322, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %19324 = stablehlo.logistic %19323 : tensor + %19325 = shape.shape_of %19324 : tensor -> tensor<2xindex> + %19326 = shape.shape_of %19323 : tensor -> tensor<2xindex> + %19327 = shape.cstr_broadcastable %19325, %19326 : tensor<2xindex>, tensor<2xindex> + %19328 = shape.assuming %19327 -> (tensor) { + %19688 = shape.broadcast %19325, %19326 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %19324, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %19323, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %19329 = shape.shape_of %19328 : tensor -> tensor<2xindex> + %19330 = shape.cstr_broadcastable %19329, %19326 : tensor<2xindex>, tensor<2xindex> + %19331 = shape.assuming %19330 -> (tensor) { + %19688 = shape.broadcast %19329, %19326 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %19328, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %19323, %19688, dims = [0, 1] : 
(tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %19332 = stablehlo.dot %19331, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_6791 = tensor.dim %19304, %c0 : tensor + %19333 = arith.index_cast %dim_6791 : index to i64 + %from_elements_6792 = tensor.from_elements %19333, %c1_i64 : tensor<2xi64> + %19334 = stablehlo.dynamic_reshape %19304, %from_elements_6792 : (tensor, tensor<2xi64>) -> tensor + %dim_6793 = tensor.dim %19301, %c0 : tensor + %19335 = arith.index_cast %dim_6793 : index to i64 + %from_elements_6794 = tensor.from_elements %19335, %c1_i64 : tensor<2xi64> + %19336 = stablehlo.dynamic_reshape %19301, %from_elements_6794 : (tensor, tensor<2xi64>) -> tensor + %19337 = stablehlo.concatenate %19334, %19336, dim = 1 : (tensor, tensor) -> tensor + %19338 = "stablehlo.gather"(%19206, %19337) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %19339 = shape.shape_of %19332 : tensor -> tensor<2xindex> + %19340 = shape.shape_of %19338 : tensor -> tensor<2xindex> + %19341 = shape.cstr_broadcastable %19339, %19340 : tensor<2xindex>, tensor<2xindex> + %19342 = shape.assuming %19341 -> (tensor) { + %19688 = shape.broadcast %19339, %19340 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %19332, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %19338, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %19343 = shape.shape_of %19342 : tensor -> tensor<2xindex> + %19344 = stablehlo.dynamic_broadcast_in_dim %19342, %19343, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19345 = stablehlo.dynamic_broadcast_in_dim %213, %19343, dims = [] : (tensor, tensor<2xindex>) -> tensor + %19346 = stablehlo.multiply %19344, %19345 : tensor + %dim_6795 = tensor.dim %19306, %c0 : tensor + %19347 = arith.index_cast %dim_6795 : index to i64 + %dim_6796 = tensor.dim %19342, %c0 : tensor + %19348 = arith.index_cast %dim_6796 : index to i64 + %19349 = arith.maxsi %19347, %19348 : i64 + %19350 = arith.index_cast %19349 : i64 to index + %from_elements_6797 = tensor.from_elements %19350, %c4096 : tensor<2xindex> + %19351 = stablehlo.dynamic_broadcast_in_dim %19306, %from_elements_6797, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_6798 = tensor.dim %19351, %c0 : tensor + %19352 = arith.index_cast %dim_6798 : index to i64 + %from_elements_6799 = tensor.from_elements %19352, %c4096_i64 : tensor<2xi64> + %19353 = stablehlo.real_dynamic_slice %19346, %c_22, %from_elements_6799, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_6800 = tensor.from_elements %19352, %c4096_i64, %c1_i64 : tensor<3xi64> + %19354 = stablehlo.dynamic_reshape %19351, %from_elements_6800 : (tensor, tensor<3xi64>) -> tensor + %19355 = stablehlo.dynamic_iota %from_elements_6800, dim = 1 : (tensor<3xi64>) -> tensor + %19356 = stablehlo.concatenate %19354, %19355, dim = 2 : (tensor, tensor) -> tensor + %19357 = "stablehlo.scatter"(%19294, %19356, %19353) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, 
tensor, tensor) -> tensor<3x4096xf32> + %19358 = stablehlo.slice %19166 [3:4, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %19359 = stablehlo.reshape %19358 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %19360 = stablehlo.custom_call @byteir.non_zero(%19359) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_6801 = tensor.dim %19360, %c0 : tensor + %19361 = arith.index_cast %dim_6801 : index to i64 + %from_elements_6802 = tensor.from_elements %19361, %c1_i64 : tensor<2xi64> + %19362 = stablehlo.real_dynamic_slice %19360, %c_22, %from_elements_6802, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6803 = tensor.dim %19362, %c0 : tensor + %19363 = arith.index_cast %dim_6803 : index to i64 + %from_elements_6804 = tensor.from_elements %19363 : tensor<1xi64> + %19364 = stablehlo.dynamic_reshape %19362, %from_elements_6804 : (tensor, tensor<1xi64>) -> tensor + %from_elements_6805 = tensor.from_elements %19361, %c2_i64 : tensor<2xi64> + %19365 = stablehlo.real_dynamic_slice %19360, %c_24, %from_elements_6805, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6806 = tensor.dim %19365, %c0 : tensor + %19366 = arith.index_cast %dim_6806 : index to i64 + %from_elements_6807 = tensor.from_elements %19366 : tensor<1xi64> + %19367 = stablehlo.dynamic_reshape %19365, %from_elements_6807 : (tensor, tensor<1xi64>) -> tensor + %dim_6808 = tensor.dim %19367, %c0 : tensor + %19368 = arith.index_cast %dim_6808 : index to i64 + %from_elements_6809 = tensor.from_elements %19368, %c1_i64 : tensor<2xi64> + %19369 = stablehlo.dynamic_reshape %19367, %from_elements_6809 : (tensor, tensor<2xi64>) -> tensor + %dim_6810 = tensor.dim %19369, %c0 : tensor + %19370 = arith.index_cast %dim_6810 : index to i64 + %from_elements_6811 = tensor.from_elements %c1_i64, %19370, %c4096_i64 : tensor<3xi64> + %19371 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_6811, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6812 = tensor.dim %19371, %c1 : tensor<1x?x4096xi64> + %19372 = arith.index_cast %dim_6812 : index to i64 + %from_elements_6813 = tensor.from_elements %c1_i64, %19372, %c4096_i64, %c1_i64 : tensor<4xi64> + %19373 = stablehlo.dynamic_reshape %19371, %from_elements_6813 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %19374 = stablehlo.dynamic_broadcast_in_dim %19369, %from_elements_6811, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6814 = tensor.dim %19374, %c1 : tensor<1x?x4096xi64> + %19375 = arith.index_cast %dim_6814 : index to i64 + %from_elements_6815 = tensor.from_elements %c1_i64, %19375, %c4096_i64, %c1_i64 : tensor<4xi64> + %19376 = stablehlo.dynamic_reshape %19374, %from_elements_6815 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %19377 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_6811, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6816 = tensor.dim %19377, %c1 : tensor<1x?x4096xi64> + %19378 = arith.index_cast %dim_6816 : index to i64 + %from_elements_6817 = tensor.from_elements %c1_i64, %19378, %c4096_i64, %c1_i64 : tensor<4xi64> + %19379 = stablehlo.dynamic_reshape %19377, %from_elements_6817 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %19380 = stablehlo.concatenate %19373, %19376, %19379, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %19381 = "stablehlo.gather"(%19177, %19380) 
<{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %19382 = shape.shape_of %19381 : tensor<1x?x4096xf32> -> tensor<3xindex> + %19383 = shape.num_elements %19382 : tensor<3xindex> -> index + %19384 = stablehlo.compute_reshape_shape %19383, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %19385 = stablehlo.dynamic_reshape %19381, %19384 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %19386 = stablehlo.dot %19385, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %19387 = stablehlo.logistic %19386 : tensor + %19388 = shape.shape_of %19387 : tensor -> tensor<2xindex> + %19389 = shape.shape_of %19386 : tensor -> tensor<2xindex> + %19390 = shape.cstr_broadcastable %19388, %19389 : tensor<2xindex>, tensor<2xindex> + %19391 = shape.assuming %19390 -> (tensor) { + %19688 = shape.broadcast %19388, %19389 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %19387, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %19386, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %19392 = shape.shape_of %19391 : tensor -> tensor<2xindex> + %19393 = shape.cstr_broadcastable %19392, %19389 : tensor<2xindex>, tensor<2xindex> + %19394 = shape.assuming %19393 -> (tensor) { + %19688 = shape.broadcast %19392, %19389 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %19391, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %19386, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %19395 = stablehlo.dot %19394, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_6818 = tensor.dim %19367, %c0 : tensor + %19396 = arith.index_cast %dim_6818 : index to i64 + %from_elements_6819 = tensor.from_elements %19396, %c1_i64 : tensor<2xi64> + %19397 = stablehlo.dynamic_reshape %19367, %from_elements_6819 : (tensor, tensor<2xi64>) -> tensor + %dim_6820 = tensor.dim %19364, %c0 : tensor + %19398 = arith.index_cast %dim_6820 : index to i64 + %from_elements_6821 = tensor.from_elements %19398, %c1_i64 : tensor<2xi64> + %19399 = stablehlo.dynamic_reshape %19364, %from_elements_6821 : (tensor, tensor<2xi64>) -> tensor + %19400 = stablehlo.concatenate %19397, %19399, dim = 1 : (tensor, tensor) -> tensor + %19401 = "stablehlo.gather"(%19206, %19400) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %19402 = shape.shape_of %19395 : tensor -> tensor<2xindex> + %19403 = shape.shape_of %19401 : tensor -> tensor<2xindex> + %19404 = shape.cstr_broadcastable %19402, %19403 : tensor<2xindex>, tensor<2xindex> + %19405 = shape.assuming %19404 -> (tensor) { + %19688 = shape.broadcast %19402, %19403 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %19395, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %19401, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %19406 = shape.shape_of %19405 : tensor -> 
tensor<2xindex> + %19407 = stablehlo.dynamic_broadcast_in_dim %19405, %19406, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19408 = stablehlo.dynamic_broadcast_in_dim %213, %19406, dims = [] : (tensor, tensor<2xindex>) -> tensor + %19409 = stablehlo.multiply %19407, %19408 : tensor + %dim_6822 = tensor.dim %19369, %c0 : tensor + %19410 = arith.index_cast %dim_6822 : index to i64 + %dim_6823 = tensor.dim %19405, %c0 : tensor + %19411 = arith.index_cast %dim_6823 : index to i64 + %19412 = arith.maxsi %19410, %19411 : i64 + %19413 = arith.index_cast %19412 : i64 to index + %from_elements_6824 = tensor.from_elements %19413, %c4096 : tensor<2xindex> + %19414 = stablehlo.dynamic_broadcast_in_dim %19369, %from_elements_6824, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_6825 = tensor.dim %19414, %c0 : tensor + %19415 = arith.index_cast %dim_6825 : index to i64 + %from_elements_6826 = tensor.from_elements %19415, %c4096_i64 : tensor<2xi64> + %19416 = stablehlo.real_dynamic_slice %19409, %c_22, %from_elements_6826, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_6827 = tensor.from_elements %19415, %c4096_i64, %c1_i64 : tensor<3xi64> + %19417 = stablehlo.dynamic_reshape %19414, %from_elements_6827 : (tensor, tensor<3xi64>) -> tensor + %19418 = stablehlo.dynamic_iota %from_elements_6827, dim = 1 : (tensor<3xi64>) -> tensor + %19419 = stablehlo.concatenate %19417, %19418, dim = 2 : (tensor, tensor) -> tensor + %19420 = "stablehlo.scatter"(%19357, %19419, %19416) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %19421 = stablehlo.slice %19166 [4:5, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %19422 = stablehlo.reshape %19421 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %19423 = stablehlo.custom_call @byteir.non_zero(%19422) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_6828 = tensor.dim %19423, %c0 : tensor + %19424 = arith.index_cast %dim_6828 : index to i64 + %from_elements_6829 = tensor.from_elements %19424, %c1_i64 : tensor<2xi64> + %19425 = stablehlo.real_dynamic_slice %19423, %c_22, %from_elements_6829, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6830 = tensor.dim %19425, %c0 : tensor + %19426 = arith.index_cast %dim_6830 : index to i64 + %from_elements_6831 = tensor.from_elements %19426 : tensor<1xi64> + %19427 = stablehlo.dynamic_reshape %19425, %from_elements_6831 : (tensor, tensor<1xi64>) -> tensor + %from_elements_6832 = tensor.from_elements %19424, %c2_i64 : tensor<2xi64> + %19428 = stablehlo.real_dynamic_slice %19423, %c_24, %from_elements_6832, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6833 = tensor.dim %19428, %c0 : tensor + %19429 = arith.index_cast %dim_6833 : index to i64 + %from_elements_6834 = tensor.from_elements %19429 : tensor<1xi64> + %19430 = stablehlo.dynamic_reshape %19428, %from_elements_6834 : (tensor, tensor<1xi64>) -> tensor + %dim_6835 = tensor.dim %19430, %c0 : tensor + %19431 = arith.index_cast %dim_6835 : index to i64 + %from_elements_6836 = tensor.from_elements %19431, %c1_i64 : tensor<2xi64> + %19432 = stablehlo.dynamic_reshape %19430, %from_elements_6836 : (tensor, tensor<2xi64>) -> tensor + %dim_6837 = tensor.dim %19432, %c0 : tensor + %19433 = arith.index_cast 
%dim_6837 : index to i64 + %from_elements_6838 = tensor.from_elements %c1_i64, %19433, %c4096_i64 : tensor<3xi64> + %19434 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_6838, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6839 = tensor.dim %19434, %c1 : tensor<1x?x4096xi64> + %19435 = arith.index_cast %dim_6839 : index to i64 + %from_elements_6840 = tensor.from_elements %c1_i64, %19435, %c4096_i64, %c1_i64 : tensor<4xi64> + %19436 = stablehlo.dynamic_reshape %19434, %from_elements_6840 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %19437 = stablehlo.dynamic_broadcast_in_dim %19432, %from_elements_6838, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6841 = tensor.dim %19437, %c1 : tensor<1x?x4096xi64> + %19438 = arith.index_cast %dim_6841 : index to i64 + %from_elements_6842 = tensor.from_elements %c1_i64, %19438, %c4096_i64, %c1_i64 : tensor<4xi64> + %19439 = stablehlo.dynamic_reshape %19437, %from_elements_6842 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %19440 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_6838, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6843 = tensor.dim %19440, %c1 : tensor<1x?x4096xi64> + %19441 = arith.index_cast %dim_6843 : index to i64 + %from_elements_6844 = tensor.from_elements %c1_i64, %19441, %c4096_i64, %c1_i64 : tensor<4xi64> + %19442 = stablehlo.dynamic_reshape %19440, %from_elements_6844 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %19443 = stablehlo.concatenate %19436, %19439, %19442, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %19444 = "stablehlo.gather"(%19177, %19443) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %19445 = shape.shape_of %19444 : tensor<1x?x4096xf32> -> tensor<3xindex> + %19446 = shape.num_elements %19445 : tensor<3xindex> -> index + %19447 = stablehlo.compute_reshape_shape %19446, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %19448 = stablehlo.dynamic_reshape %19444, %19447 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %19449 = stablehlo.dot %19448, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %19450 = stablehlo.logistic %19449 : tensor + %19451 = shape.shape_of %19450 : tensor -> tensor<2xindex> + %19452 = shape.shape_of %19449 : tensor -> tensor<2xindex> + %19453 = shape.cstr_broadcastable %19451, %19452 : tensor<2xindex>, tensor<2xindex> + %19454 = shape.assuming %19453 -> (tensor) { + %19688 = shape.broadcast %19451, %19452 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %19450, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %19449, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %19455 = shape.shape_of %19454 : tensor -> tensor<2xindex> + %19456 = shape.cstr_broadcastable %19455, %19452 : tensor<2xindex>, tensor<2xindex> + %19457 = shape.assuming %19456 -> (tensor) { + %19688 = shape.broadcast %19455, %19452 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %19454, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %19449, 
%19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %19458 = stablehlo.dot %19457, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_6845 = tensor.dim %19430, %c0 : tensor + %19459 = arith.index_cast %dim_6845 : index to i64 + %from_elements_6846 = tensor.from_elements %19459, %c1_i64 : tensor<2xi64> + %19460 = stablehlo.dynamic_reshape %19430, %from_elements_6846 : (tensor, tensor<2xi64>) -> tensor + %dim_6847 = tensor.dim %19427, %c0 : tensor + %19461 = arith.index_cast %dim_6847 : index to i64 + %from_elements_6848 = tensor.from_elements %19461, %c1_i64 : tensor<2xi64> + %19462 = stablehlo.dynamic_reshape %19427, %from_elements_6848 : (tensor, tensor<2xi64>) -> tensor + %19463 = stablehlo.concatenate %19460, %19462, dim = 1 : (tensor, tensor) -> tensor + %19464 = "stablehlo.gather"(%19206, %19463) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %19465 = shape.shape_of %19458 : tensor -> tensor<2xindex> + %19466 = shape.shape_of %19464 : tensor -> tensor<2xindex> + %19467 = shape.cstr_broadcastable %19465, %19466 : tensor<2xindex>, tensor<2xindex> + %19468 = shape.assuming %19467 -> (tensor) { + %19688 = shape.broadcast %19465, %19466 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %19458, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %19464, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %19469 = shape.shape_of %19468 : tensor -> tensor<2xindex> + %19470 = stablehlo.dynamic_broadcast_in_dim %19468, %19469, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19471 = stablehlo.dynamic_broadcast_in_dim %213, %19469, dims = [] : (tensor, tensor<2xindex>) -> tensor + %19472 = stablehlo.multiply %19470, %19471 : tensor + %dim_6849 = tensor.dim %19432, %c0 : tensor + %19473 = arith.index_cast %dim_6849 : index to i64 + %dim_6850 = tensor.dim %19468, %c0 : tensor + %19474 = arith.index_cast %dim_6850 : index to i64 + %19475 = arith.maxsi %19473, %19474 : i64 + %19476 = arith.index_cast %19475 : i64 to index + %from_elements_6851 = tensor.from_elements %19476, %c4096 : tensor<2xindex> + %19477 = stablehlo.dynamic_broadcast_in_dim %19432, %from_elements_6851, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_6852 = tensor.dim %19477, %c0 : tensor + %19478 = arith.index_cast %dim_6852 : index to i64 + %from_elements_6853 = tensor.from_elements %19478, %c4096_i64 : tensor<2xi64> + %19479 = stablehlo.real_dynamic_slice %19472, %c_22, %from_elements_6853, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_6854 = tensor.from_elements %19478, %c4096_i64, %c1_i64 : tensor<3xi64> + %19480 = stablehlo.dynamic_reshape %19477, %from_elements_6854 : (tensor, tensor<3xi64>) -> tensor + %19481 = stablehlo.dynamic_iota %from_elements_6854, dim = 1 : (tensor<3xi64>) -> tensor + %19482 = stablehlo.concatenate %19480, %19481, dim = 2 : (tensor, tensor) -> tensor + %19483 = "stablehlo.scatter"(%19420, %19482, %19479) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) 
: (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %19484 = stablehlo.slice %19166 [5:6, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %19485 = stablehlo.reshape %19484 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %19486 = stablehlo.custom_call @byteir.non_zero(%19485) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_6855 = tensor.dim %19486, %c0 : tensor + %19487 = arith.index_cast %dim_6855 : index to i64 + %from_elements_6856 = tensor.from_elements %19487, %c1_i64 : tensor<2xi64> + %19488 = stablehlo.real_dynamic_slice %19486, %c_22, %from_elements_6856, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6857 = tensor.dim %19488, %c0 : tensor + %19489 = arith.index_cast %dim_6857 : index to i64 + %from_elements_6858 = tensor.from_elements %19489 : tensor<1xi64> + %19490 = stablehlo.dynamic_reshape %19488, %from_elements_6858 : (tensor, tensor<1xi64>) -> tensor + %from_elements_6859 = tensor.from_elements %19487, %c2_i64 : tensor<2xi64> + %19491 = stablehlo.real_dynamic_slice %19486, %c_24, %from_elements_6859, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6860 = tensor.dim %19491, %c0 : tensor + %19492 = arith.index_cast %dim_6860 : index to i64 + %from_elements_6861 = tensor.from_elements %19492 : tensor<1xi64> + %19493 = stablehlo.dynamic_reshape %19491, %from_elements_6861 : (tensor, tensor<1xi64>) -> tensor + %dim_6862 = tensor.dim %19493, %c0 : tensor + %19494 = arith.index_cast %dim_6862 : index to i64 + %from_elements_6863 = tensor.from_elements %19494, %c1_i64 : tensor<2xi64> + %19495 = stablehlo.dynamic_reshape %19493, %from_elements_6863 : (tensor, tensor<2xi64>) -> tensor + %dim_6864 = tensor.dim %19495, %c0 : tensor + %19496 = arith.index_cast %dim_6864 : index to i64 + %from_elements_6865 = tensor.from_elements %c1_i64, %19496, %c4096_i64 : tensor<3xi64> + %19497 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_6865, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6866 = tensor.dim %19497, %c1 : tensor<1x?x4096xi64> + %19498 = arith.index_cast %dim_6866 : index to i64 + %from_elements_6867 = tensor.from_elements %c1_i64, %19498, %c4096_i64, %c1_i64 : tensor<4xi64> + %19499 = stablehlo.dynamic_reshape %19497, %from_elements_6867 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %19500 = stablehlo.dynamic_broadcast_in_dim %19495, %from_elements_6865, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6868 = tensor.dim %19500, %c1 : tensor<1x?x4096xi64> + %19501 = arith.index_cast %dim_6868 : index to i64 + %from_elements_6869 = tensor.from_elements %c1_i64, %19501, %c4096_i64, %c1_i64 : tensor<4xi64> + %19502 = stablehlo.dynamic_reshape %19500, %from_elements_6869 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %19503 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_6865, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6870 = tensor.dim %19503, %c1 : tensor<1x?x4096xi64> + %19504 = arith.index_cast %dim_6870 : index to i64 + %from_elements_6871 = tensor.from_elements %c1_i64, %19504, %c4096_i64, %c1_i64 : tensor<4xi64> + %19505 = stablehlo.dynamic_reshape %19503, %from_elements_6871 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %19506 = stablehlo.concatenate %19499, %19502, %19505, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %19507 = "stablehlo.gather"(%19177, 
%19506) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %19508 = shape.shape_of %19507 : tensor<1x?x4096xf32> -> tensor<3xindex> + %19509 = shape.num_elements %19508 : tensor<3xindex> -> index + %19510 = stablehlo.compute_reshape_shape %19509, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %19511 = stablehlo.dynamic_reshape %19507, %19510 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %19512 = stablehlo.dot %19511, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %19513 = stablehlo.logistic %19512 : tensor + %19514 = shape.shape_of %19513 : tensor -> tensor<2xindex> + %19515 = shape.shape_of %19512 : tensor -> tensor<2xindex> + %19516 = shape.cstr_broadcastable %19514, %19515 : tensor<2xindex>, tensor<2xindex> + %19517 = shape.assuming %19516 -> (tensor) { + %19688 = shape.broadcast %19514, %19515 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %19513, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %19512, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %19518 = shape.shape_of %19517 : tensor -> tensor<2xindex> + %19519 = shape.cstr_broadcastable %19518, %19515 : tensor<2xindex>, tensor<2xindex> + %19520 = shape.assuming %19519 -> (tensor) { + %19688 = shape.broadcast %19518, %19515 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %19517, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %19512, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %19521 = stablehlo.dot %19520, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_6872 = tensor.dim %19493, %c0 : tensor + %19522 = arith.index_cast %dim_6872 : index to i64 + %from_elements_6873 = tensor.from_elements %19522, %c1_i64 : tensor<2xi64> + %19523 = stablehlo.dynamic_reshape %19493, %from_elements_6873 : (tensor, tensor<2xi64>) -> tensor + %dim_6874 = tensor.dim %19490, %c0 : tensor + %19524 = arith.index_cast %dim_6874 : index to i64 + %from_elements_6875 = tensor.from_elements %19524, %c1_i64 : tensor<2xi64> + %19525 = stablehlo.dynamic_reshape %19490, %from_elements_6875 : (tensor, tensor<2xi64>) -> tensor + %19526 = stablehlo.concatenate %19523, %19525, dim = 1 : (tensor, tensor) -> tensor + %19527 = "stablehlo.gather"(%19206, %19526) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %19528 = shape.shape_of %19521 : tensor -> tensor<2xindex> + %19529 = shape.shape_of %19527 : tensor -> tensor<2xindex> + %19530 = shape.cstr_broadcastable %19528, %19529 : tensor<2xindex>, tensor<2xindex> + %19531 = shape.assuming %19530 -> (tensor) { + %19688 = shape.broadcast %19528, %19529 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %19521, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %19527, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %19532 = shape.shape_of %19531 : tensor -> 
tensor<2xindex> + %19533 = stablehlo.dynamic_broadcast_in_dim %19531, %19532, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19534 = stablehlo.dynamic_broadcast_in_dim %213, %19532, dims = [] : (tensor, tensor<2xindex>) -> tensor + %19535 = stablehlo.multiply %19533, %19534 : tensor + %dim_6876 = tensor.dim %19495, %c0 : tensor + %19536 = arith.index_cast %dim_6876 : index to i64 + %dim_6877 = tensor.dim %19531, %c0 : tensor + %19537 = arith.index_cast %dim_6877 : index to i64 + %19538 = arith.maxsi %19536, %19537 : i64 + %19539 = arith.index_cast %19538 : i64 to index + %from_elements_6878 = tensor.from_elements %19539, %c4096 : tensor<2xindex> + %19540 = stablehlo.dynamic_broadcast_in_dim %19495, %from_elements_6878, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_6879 = tensor.dim %19540, %c0 : tensor + %19541 = arith.index_cast %dim_6879 : index to i64 + %from_elements_6880 = tensor.from_elements %19541, %c4096_i64 : tensor<2xi64> + %19542 = stablehlo.real_dynamic_slice %19535, %c_22, %from_elements_6880, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_6881 = tensor.from_elements %19541, %c4096_i64, %c1_i64 : tensor<3xi64> + %19543 = stablehlo.dynamic_reshape %19540, %from_elements_6881 : (tensor, tensor<3xi64>) -> tensor + %19544 = stablehlo.dynamic_iota %from_elements_6881, dim = 1 : (tensor<3xi64>) -> tensor + %19545 = stablehlo.concatenate %19543, %19544, dim = 2 : (tensor, tensor) -> tensor + %19546 = "stablehlo.scatter"(%19483, %19545, %19542) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %19547 = stablehlo.slice %19166 [6:7, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %19548 = stablehlo.reshape %19547 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %19549 = stablehlo.custom_call @byteir.non_zero(%19548) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_6882 = tensor.dim %19549, %c0 : tensor + %19550 = arith.index_cast %dim_6882 : index to i64 + %from_elements_6883 = tensor.from_elements %19550, %c1_i64 : tensor<2xi64> + %19551 = stablehlo.real_dynamic_slice %19549, %c_22, %from_elements_6883, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6884 = tensor.dim %19551, %c0 : tensor + %19552 = arith.index_cast %dim_6884 : index to i64 + %from_elements_6885 = tensor.from_elements %19552 : tensor<1xi64> + %19553 = stablehlo.dynamic_reshape %19551, %from_elements_6885 : (tensor, tensor<1xi64>) -> tensor + %from_elements_6886 = tensor.from_elements %19550, %c2_i64 : tensor<2xi64> + %19554 = stablehlo.real_dynamic_slice %19549, %c_24, %from_elements_6886, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6887 = tensor.dim %19554, %c0 : tensor + %19555 = arith.index_cast %dim_6887 : index to i64 + %from_elements_6888 = tensor.from_elements %19555 : tensor<1xi64> + %19556 = stablehlo.dynamic_reshape %19554, %from_elements_6888 : (tensor, tensor<1xi64>) -> tensor + %dim_6889 = tensor.dim %19556, %c0 : tensor + %19557 = arith.index_cast %dim_6889 : index to i64 + %from_elements_6890 = tensor.from_elements %19557, %c1_i64 : tensor<2xi64> + %19558 = stablehlo.dynamic_reshape %19556, %from_elements_6890 : (tensor, tensor<2xi64>) -> tensor + %dim_6891 = tensor.dim %19558, %c0 : tensor + %19559 = arith.index_cast 
%dim_6891 : index to i64 + %from_elements_6892 = tensor.from_elements %c1_i64, %19559, %c4096_i64 : tensor<3xi64> + %19560 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_6892, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6893 = tensor.dim %19560, %c1 : tensor<1x?x4096xi64> + %19561 = arith.index_cast %dim_6893 : index to i64 + %from_elements_6894 = tensor.from_elements %c1_i64, %19561, %c4096_i64, %c1_i64 : tensor<4xi64> + %19562 = stablehlo.dynamic_reshape %19560, %from_elements_6894 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %19563 = stablehlo.dynamic_broadcast_in_dim %19558, %from_elements_6892, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6895 = tensor.dim %19563, %c1 : tensor<1x?x4096xi64> + %19564 = arith.index_cast %dim_6895 : index to i64 + %from_elements_6896 = tensor.from_elements %c1_i64, %19564, %c4096_i64, %c1_i64 : tensor<4xi64> + %19565 = stablehlo.dynamic_reshape %19563, %from_elements_6896 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %19566 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_6892, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6897 = tensor.dim %19566, %c1 : tensor<1x?x4096xi64> + %19567 = arith.index_cast %dim_6897 : index to i64 + %from_elements_6898 = tensor.from_elements %c1_i64, %19567, %c4096_i64, %c1_i64 : tensor<4xi64> + %19568 = stablehlo.dynamic_reshape %19566, %from_elements_6898 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %19569 = stablehlo.concatenate %19562, %19565, %19568, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %19570 = "stablehlo.gather"(%19177, %19569) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %19571 = shape.shape_of %19570 : tensor<1x?x4096xf32> -> tensor<3xindex> + %19572 = shape.num_elements %19571 : tensor<3xindex> -> index + %19573 = stablehlo.compute_reshape_shape %19572, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %19574 = stablehlo.dynamic_reshape %19570, %19573 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %19575 = stablehlo.dot %19574, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %19576 = stablehlo.logistic %19575 : tensor + %19577 = shape.shape_of %19576 : tensor -> tensor<2xindex> + %19578 = shape.shape_of %19575 : tensor -> tensor<2xindex> + %19579 = shape.cstr_broadcastable %19577, %19578 : tensor<2xindex>, tensor<2xindex> + %19580 = shape.assuming %19579 -> (tensor) { + %19688 = shape.broadcast %19577, %19578 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %19576, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %19575, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %19581 = shape.shape_of %19580 : tensor -> tensor<2xindex> + %19582 = shape.cstr_broadcastable %19581, %19578 : tensor<2xindex>, tensor<2xindex> + %19583 = shape.assuming %19582 -> (tensor) { + %19688 = shape.broadcast %19581, %19578 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %19580, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %19575, 
%19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %19584 = stablehlo.dot %19583, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_6899 = tensor.dim %19556, %c0 : tensor + %19585 = arith.index_cast %dim_6899 : index to i64 + %from_elements_6900 = tensor.from_elements %19585, %c1_i64 : tensor<2xi64> + %19586 = stablehlo.dynamic_reshape %19556, %from_elements_6900 : (tensor, tensor<2xi64>) -> tensor + %dim_6901 = tensor.dim %19553, %c0 : tensor + %19587 = arith.index_cast %dim_6901 : index to i64 + %from_elements_6902 = tensor.from_elements %19587, %c1_i64 : tensor<2xi64> + %19588 = stablehlo.dynamic_reshape %19553, %from_elements_6902 : (tensor, tensor<2xi64>) -> tensor + %19589 = stablehlo.concatenate %19586, %19588, dim = 1 : (tensor, tensor) -> tensor + %19590 = "stablehlo.gather"(%19206, %19589) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %19591 = shape.shape_of %19584 : tensor -> tensor<2xindex> + %19592 = shape.shape_of %19590 : tensor -> tensor<2xindex> + %19593 = shape.cstr_broadcastable %19591, %19592 : tensor<2xindex>, tensor<2xindex> + %19594 = shape.assuming %19593 -> (tensor) { + %19688 = shape.broadcast %19591, %19592 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %19584, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %19590, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %19595 = shape.shape_of %19594 : tensor -> tensor<2xindex> + %19596 = stablehlo.dynamic_broadcast_in_dim %19594, %19595, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19597 = stablehlo.dynamic_broadcast_in_dim %213, %19595, dims = [] : (tensor, tensor<2xindex>) -> tensor + %19598 = stablehlo.multiply %19596, %19597 : tensor + %dim_6903 = tensor.dim %19558, %c0 : tensor + %19599 = arith.index_cast %dim_6903 : index to i64 + %dim_6904 = tensor.dim %19594, %c0 : tensor + %19600 = arith.index_cast %dim_6904 : index to i64 + %19601 = arith.maxsi %19599, %19600 : i64 + %19602 = arith.index_cast %19601 : i64 to index + %from_elements_6905 = tensor.from_elements %19602, %c4096 : tensor<2xindex> + %19603 = stablehlo.dynamic_broadcast_in_dim %19558, %from_elements_6905, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_6906 = tensor.dim %19603, %c0 : tensor + %19604 = arith.index_cast %dim_6906 : index to i64 + %from_elements_6907 = tensor.from_elements %19604, %c4096_i64 : tensor<2xi64> + %19605 = stablehlo.real_dynamic_slice %19598, %c_22, %from_elements_6907, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_6908 = tensor.from_elements %19604, %c4096_i64, %c1_i64 : tensor<3xi64> + %19606 = stablehlo.dynamic_reshape %19603, %from_elements_6908 : (tensor, tensor<3xi64>) -> tensor + %19607 = stablehlo.dynamic_iota %from_elements_6908, dim = 1 : (tensor<3xi64>) -> tensor + %19608 = stablehlo.concatenate %19606, %19607, dim = 2 : (tensor, tensor) -> tensor + %19609 = "stablehlo.scatter"(%19546, %19608, %19605) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) 
: (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %19610 = stablehlo.slice %19166 [7:8, 0:2, 0:3] : (tensor<8x2x3xi64>) -> tensor<1x2x3xi64> + %19611 = stablehlo.reshape %19610 : (tensor<1x2x3xi64>) -> tensor<2x3xi64> + %19612 = stablehlo.custom_call @byteir.non_zero(%19611) {byteir_attrs = {}} : (tensor<2x3xi64>) -> tensor + %dim_6909 = tensor.dim %19612, %c0 : tensor + %19613 = arith.index_cast %dim_6909 : index to i64 + %from_elements_6910 = tensor.from_elements %19613, %c1_i64 : tensor<2xi64> + %19614 = stablehlo.real_dynamic_slice %19612, %c_22, %from_elements_6910, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6911 = tensor.dim %19614, %c0 : tensor + %19615 = arith.index_cast %dim_6911 : index to i64 + %from_elements_6912 = tensor.from_elements %19615 : tensor<1xi64> + %19616 = stablehlo.dynamic_reshape %19614, %from_elements_6912 : (tensor, tensor<1xi64>) -> tensor + %from_elements_6913 = tensor.from_elements %19613, %c2_i64 : tensor<2xi64> + %19617 = stablehlo.real_dynamic_slice %19612, %c_24, %from_elements_6913, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_6914 = tensor.dim %19617, %c0 : tensor + %19618 = arith.index_cast %dim_6914 : index to i64 + %from_elements_6915 = tensor.from_elements %19618 : tensor<1xi64> + %19619 = stablehlo.dynamic_reshape %19617, %from_elements_6915 : (tensor, tensor<1xi64>) -> tensor + %dim_6916 = tensor.dim %19619, %c0 : tensor + %19620 = arith.index_cast %dim_6916 : index to i64 + %from_elements_6917 = tensor.from_elements %19620, %c1_i64 : tensor<2xi64> + %19621 = stablehlo.dynamic_reshape %19619, %from_elements_6917 : (tensor, tensor<2xi64>) -> tensor + %dim_6918 = tensor.dim %19621, %c0 : tensor + %19622 = arith.index_cast %dim_6918 : index to i64 + %from_elements_6919 = tensor.from_elements %c1_i64, %19622, %c4096_i64 : tensor<3xi64> + %19623 = stablehlo.dynamic_broadcast_in_dim %173, %from_elements_6919, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6920 = tensor.dim %19623, %c1 : tensor<1x?x4096xi64> + %19624 = arith.index_cast %dim_6920 : index to i64 + %from_elements_6921 = tensor.from_elements %c1_i64, %19624, %c4096_i64, %c1_i64 : tensor<4xi64> + %19625 = stablehlo.dynamic_reshape %19623, %from_elements_6921 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %19626 = stablehlo.dynamic_broadcast_in_dim %19621, %from_elements_6919, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6922 = tensor.dim %19626, %c1 : tensor<1x?x4096xi64> + %19627 = arith.index_cast %dim_6922 : index to i64 + %from_elements_6923 = tensor.from_elements %c1_i64, %19627, %c4096_i64, %c1_i64 : tensor<4xi64> + %19628 = stablehlo.dynamic_reshape %19626, %from_elements_6923 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %19629 = stablehlo.dynamic_broadcast_in_dim %160, %from_elements_6919, dims = [2] : (tensor<4096xi64>, tensor<3xi64>) -> tensor<1x?x4096xi64> + %dim_6924 = tensor.dim %19629, %c1 : tensor<1x?x4096xi64> + %19630 = arith.index_cast %dim_6924 : index to i64 + %from_elements_6925 = tensor.from_elements %c1_i64, %19630, %c4096_i64, %c1_i64 : tensor<4xi64> + %19631 = stablehlo.dynamic_reshape %19629, %from_elements_6925 : (tensor<1x?x4096xi64>, tensor<4xi64>) -> tensor<1x?x4096x1xi64> + %19632 = stablehlo.concatenate %19625, %19628, %19631, dim = 3 : (tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>, tensor<1x?x4096x1xi64>) -> tensor<1x?x4096x3xi64> + %19633 = "stablehlo.gather"(%19177, 
%19632) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x3x4096xf32>, tensor<1x?x4096x3xi64>) -> tensor<1x?x4096xf32> + %19634 = shape.shape_of %19633 : tensor<1x?x4096xf32> -> tensor<3xindex> + %19635 = shape.num_elements %19634 : tensor<3xindex> -> index + %19636 = stablehlo.compute_reshape_shape %19635, %c_23 : (index, tensor<2xi64>) -> tensor<2xi64> + %19637 = stablehlo.dynamic_reshape %19633, %19636 : (tensor<1x?x4096xf32>, tensor<2xi64>) -> tensor + %19638 = stablehlo.dot %19637, %190 : (tensor, tensor<4096x14336xf32>) -> tensor + %19639 = stablehlo.logistic %19638 : tensor + %19640 = shape.shape_of %19639 : tensor -> tensor<2xindex> + %19641 = shape.shape_of %19638 : tensor -> tensor<2xindex> + %19642 = shape.cstr_broadcastable %19640, %19641 : tensor<2xindex>, tensor<2xindex> + %19643 = shape.assuming %19642 -> (tensor) { + %19688 = shape.broadcast %19640, %19641 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %19639, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %19638, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %19644 = shape.shape_of %19643 : tensor -> tensor<2xindex> + %19645 = shape.cstr_broadcastable %19644, %19641 : tensor<2xindex>, tensor<2xindex> + %19646 = shape.assuming %19645 -> (tensor) { + %19688 = shape.broadcast %19644, %19641 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %19643, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %19638, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %19647 = stablehlo.dot %19646, %200 : (tensor, tensor<14336x4096xf32>) -> tensor + %dim_6926 = tensor.dim %19619, %c0 : tensor + %19648 = arith.index_cast %dim_6926 : index to i64 + %from_elements_6927 = tensor.from_elements %19648, %c1_i64 : tensor<2xi64> + %19649 = stablehlo.dynamic_reshape %19619, %from_elements_6927 : (tensor, tensor<2xi64>) -> tensor + %dim_6928 = tensor.dim %19616, %c0 : tensor + %19650 = arith.index_cast %dim_6928 : index to i64 + %from_elements_6929 = tensor.from_elements %19650, %c1_i64 : tensor<2xi64> + %19651 = stablehlo.dynamic_reshape %19616, %from_elements_6929 : (tensor, tensor<2xi64>) -> tensor + %19652 = stablehlo.concatenate %19649, %19651, dim = 1 : (tensor, tensor) -> tensor + %19653 = "stablehlo.gather"(%19206, %19652) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2x1xf32>, tensor) -> tensor + %19654 = shape.shape_of %19647 : tensor -> tensor<2xindex> + %19655 = shape.shape_of %19653 : tensor -> tensor<2xindex> + %19656 = shape.cstr_broadcastable %19654, %19655 : tensor<2xindex>, tensor<2xindex> + %19657 = shape.assuming %19656 -> (tensor) { + %19688 = shape.broadcast %19654, %19655 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %19689 = stablehlo.dynamic_broadcast_in_dim %19647, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19690 = stablehlo.dynamic_broadcast_in_dim %19653, %19688, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19691 = stablehlo.multiply %19689, %19690 : tensor + shape.assuming_yield %19691 : tensor + } + %19658 = shape.shape_of %19657 : tensor -> 
tensor<2xindex> + %19659 = stablehlo.dynamic_broadcast_in_dim %19657, %19658, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %19660 = stablehlo.dynamic_broadcast_in_dim %213, %19658, dims = [] : (tensor, tensor<2xindex>) -> tensor + %19661 = stablehlo.multiply %19659, %19660 : tensor + %dim_6930 = tensor.dim %19621, %c0 : tensor + %19662 = arith.index_cast %dim_6930 : index to i64 + %dim_6931 = tensor.dim %19657, %c0 : tensor + %19663 = arith.index_cast %dim_6931 : index to i64 + %19664 = arith.maxsi %19662, %19663 : i64 + %19665 = arith.index_cast %19664 : i64 to index + %from_elements_6932 = tensor.from_elements %19665, %c4096 : tensor<2xindex> + %19666 = stablehlo.dynamic_broadcast_in_dim %19621, %from_elements_6932, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_6933 = tensor.dim %19666, %c0 : tensor + %19667 = arith.index_cast %dim_6933 : index to i64 + %from_elements_6934 = tensor.from_elements %19667, %c4096_i64 : tensor<2xi64> + %19668 = stablehlo.real_dynamic_slice %19661, %c_22, %from_elements_6934, %c_21 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_6935 = tensor.from_elements %19667, %c4096_i64, %c1_i64 : tensor<3xi64> + %19669 = stablehlo.dynamic_reshape %19666, %from_elements_6935 : (tensor, tensor<3xi64>) -> tensor + %19670 = stablehlo.dynamic_iota %from_elements_6935, dim = 1 : (tensor<3xi64>) -> tensor + %19671 = stablehlo.concatenate %19669, %19670, dim = 2 : (tensor, tensor) -> tensor + %19672 = "stablehlo.scatter"(%19609, %19671, %19668) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg129: tensor, %arg130: tensor): + %19688 = stablehlo.add %arg129, %arg130 : tensor + stablehlo.return %19688 : tensor + }) : (tensor<3x4096xf32>, tensor, tensor) -> tensor<3x4096xf32> + %19673 = stablehlo.reshape %19672 : (tensor<3x4096xf32>) -> tensor<3x1x4096xf32> + %19674 = stablehlo.add %19139, %19673 : tensor<3x1x4096xf32> + %19675 = stablehlo.broadcast_in_dim %19674, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %19676 = stablehlo.power %19675, %15 : tensor<3x1x4096xf32> + %19677 = stablehlo.reduce(%19676 init: %cst_20) applies stablehlo.add across dimensions = [2] : (tensor<3x1x4096xf32>, tensor) -> tensor<3x1xf32> + %19678 = stablehlo.reshape %19677 : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + %19679 = stablehlo.broadcast_in_dim %19678, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %19680 = stablehlo.divide %19679, %21 : tensor<3x1x1xf32> + %19681 = stablehlo.broadcast_in_dim %19680, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x1xf32> + %19682 = stablehlo.add %19681, %25 : tensor<3x1x1xf32> + %19683 = stablehlo.rsqrt %19682 : tensor<3x1x1xf32> + %19684 = stablehlo.broadcast_in_dim %19683, dims = [0, 1, 2] : (tensor<3x1x1xf32>) -> tensor<3x1x4096xf32> + %19685 = stablehlo.multiply %19675, %19684 : tensor<3x1x4096xf32> + %19686 = stablehlo.broadcast_in_dim %19685, dims = [0, 1, 2] : (tensor<3x1x4096xf32>) -> tensor<3x1x4096xf32> + %19687 = stablehlo.multiply %19686, %31 : tensor<3x1x4096xf32> + return %19687, %73, %74, %722, %723, %1335, %1336, %1948, %1949, %2561, %2562, %3174, %3175, %3787, %3788, %4400, %4401, %5013, %5014, %5626, %5627, %6239, %6240, %6852, %6853, %7465, %7466, %8078, %8079, %8691, %8692, %9304, %9305, %9917, %9918, %10530, %10531, %11143, %11144, %11756, %11757, %12369, %12370, %12982, %12983, %13595, %13596, %14208, %14209, %14821, %14822, %15434, %15435, %16047, %16048, %16660, %16661, %17273, 
%17274, %17886, %17887, %18499, %18500, %19112, %19113 : tensor<3x1x4096xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32>, tensor<3x8x8x128xf32> + } +} diff --git a/frontends/torch-frontend/examples/inference/mixtral/mixtral_decoder.stablehlo.elide.mlir b/frontends/torch-frontend/examples/inference/mixtral/mixtral_decoder.stablehlo.elide.mlir new file mode 100644 index 000000000..9d6361bd6 --- /dev/null +++ b/frontends/torch-frontend/examples/inference/mixtral/mixtral_decoder.stablehlo.elide.mlir @@ -0,0 +1,1158 @@ +module { + func.func @main(%arg0: tensor<131072x2xf32>, %arg1: tensor<131072x2xf32>, %arg2: tensor<5x7x32xf32>) -> tensor<5x7x32xf32> { + %cst = stablehlo.constant dense<3.200000e+01> : tensor + %cst_0 = stablehlo.constant dense<8.000000e+00> : tensor + %cst_1 = stablehlo.constant dense<1.000000e+00> : tensor + %cst_2 = stablehlo.constant dense<0.000000e+00> : tensor<35x32xf32> + %cst_3 = stablehlo.constant dense<1.000000e+00> : tensor<1xf32> + %cst_4 = stablehlo.constant dense<9.99999974E-6> : tensor<1xf32> + %cst_5 = stablehlo.constant dense<3.200000e+01> : tensor<1xf32> + %c = stablehlo.constant dense<0> : tensor<1xi64> + %c_6 = stablehlo.constant dense<1> : tensor<1xi64> + %c_7 = stablehlo.constant dense<0> : tensor<32xi64> + %c_8 = stablehlo.constant dense<1> : tensor<32xi64> + %c_9 = stablehlo.constant dense<0> : tensor<8xi64> + %c_10 = stablehlo.constant dense<1> : tensor<8xi64> + %cst_11 = stablehlo.constant dense_resource : tensor<32x14336xf32> + %cst_12 = stablehlo.constant dense_resource : tensor<14336x32xf32> + %cst_13 = stablehlo.constant dense_resource : tensor<14336x32xf32> + %cst_14 = stablehlo.constant dense_resource : tensor<32x14336xf32> + %cst_15 = stablehlo.constant dense_resource : tensor<14336x32xf32> + %cst_16 = stablehlo.constant dense_resource : tensor<14336x32xf32> + %cst_17 = stablehlo.constant dense_resource : tensor<32x14336xf32> + %cst_18 = stablehlo.constant dense_resource : tensor<14336x32xf32> + %cst_19 = stablehlo.constant dense_resource : tensor<14336x32xf32> + %cst_20 = stablehlo.constant dense_resource : tensor<32x14336xf32> + %cst_21 = stablehlo.constant 
dense_resource : tensor<14336x32xf32> + %cst_22 = stablehlo.constant dense_resource : tensor<14336x32xf32> + %cst_23 = stablehlo.constant dense_resource : tensor<32x14336xf32> + %cst_24 = stablehlo.constant dense_resource : tensor<14336x32xf32> + %cst_25 = stablehlo.constant dense_resource : tensor<14336x32xf32> + %cst_26 = stablehlo.constant dense_resource : tensor<32x14336xf32> + %cst_27 = stablehlo.constant dense_resource : tensor<14336x32xf32> + %cst_28 = stablehlo.constant dense_resource : tensor<14336x32xf32> + %cst_29 = stablehlo.constant dense_resource : tensor<32x14336xf32> + %cst_30 = stablehlo.constant dense_resource : tensor<14336x32xf32> + %cst_31 = stablehlo.constant dense_resource : tensor<14336x32xf32> + %cst_32 = stablehlo.constant dense_resource : tensor<32x14336xf32> + %cst_33 = stablehlo.constant dense_resource : tensor<14336x32xf32> + %cst_34 = stablehlo.constant dense_resource : tensor<14336x32xf32> + %cst_35 = stablehlo.constant dense_resource : tensor<8x32xf32> + %cst_36 = stablehlo.constant dense_resource : tensor<32xf32> + %cst_37 = stablehlo.constant dense_resource : tensor<32x32xf32> + %cst_38 = stablehlo.constant dense_resource : tensor<8x32xf32> + %cst_39 = stablehlo.constant dense_resource : tensor<8x32xf32> + %cst_40 = stablehlo.constant dense_resource : tensor<32x32xf32> + %cst_41 = stablehlo.constant dense_resource : tensor<32xf32> + %c1_i64 = arith.constant 1 : i64 + %cst_42 = stablehlo.constant dense<0.000000e+00> : tensor + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c_43 = stablehlo.constant dense<0> : tensor<2xi64> + %c_44 = stablehlo.constant dense<1> : tensor<2xi64> + %c_45 = stablehlo.constant dense<[-1, 32]> : tensor<2xi64> + %c2_i64 = arith.constant 2 : i64 + %c_46 = stablehlo.constant dense<[0, 1]> : tensor<2xi64> + %c32_i64 = arith.constant 32 : i64 + %c32 = arith.constant 32 : index + %cst_47 = stablehlo.constant dense<2.000000e+00> : tensor<1xf32> + %0 = stablehlo.reshape %cst_47 : (tensor<1xf32>) -> tensor + %1 = stablehlo.broadcast_in_dim %arg2, dims = [0, 1, 2] : (tensor<5x7x32xf32>) -> tensor<5x7x32xf32> + %2 = stablehlo.broadcast_in_dim %0, dims = [] : (tensor) -> tensor<5x7x32xf32> + %3 = stablehlo.power %1, %2 : tensor<5x7x32xf32> + %4 = stablehlo.reduce(%3 init: %cst_42) applies stablehlo.add across dimensions = [2] : (tensor<5x7x32xf32>, tensor) -> tensor<5x7xf32> + %5 = stablehlo.reshape %4 : (tensor<5x7xf32>) -> tensor<5x7x1xf32> + %6 = stablehlo.reshape %cst_5 : (tensor<1xf32>) -> tensor + %7 = stablehlo.broadcast_in_dim %5, dims = [0, 1, 2] : (tensor<5x7x1xf32>) -> tensor<5x7x1xf32> + %8 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<5x7x1xf32> + %9 = stablehlo.divide %7, %8 : tensor<5x7x1xf32> + %10 = stablehlo.reshape %cst_4 : (tensor<1xf32>) -> tensor + %11 = stablehlo.broadcast_in_dim %9, dims = [0, 1, 2] : (tensor<5x7x1xf32>) -> tensor<5x7x1xf32> + %12 = stablehlo.broadcast_in_dim %10, dims = [] : (tensor) -> tensor<5x7x1xf32> + %13 = stablehlo.add %11, %12 : tensor<5x7x1xf32> + %14 = stablehlo.rsqrt %13 : tensor<5x7x1xf32> + %15 = stablehlo.broadcast_in_dim %14, dims = [0, 1, 2] : (tensor<5x7x1xf32>) -> tensor<5x7x32xf32> + %16 = stablehlo.multiply %1, %15 : tensor<5x7x32xf32> + %17 = stablehlo.broadcast_in_dim %16, dims = [0, 1, 2] : (tensor<5x7x32xf32>) -> tensor<5x7x32xf32> + %18 = stablehlo.broadcast_in_dim %cst_41, dims = [2] : (tensor<32xf32>) -> tensor<5x7x32xf32> + %19 = stablehlo.multiply %17, %18 : tensor<5x7x32xf32> + %20 = stablehlo.transpose %cst_40, dims = [1, 0] : 
(tensor<32x32xf32>) -> tensor<32x32xf32> + %21 = stablehlo.reshape %19 : (tensor<5x7x32xf32>) -> tensor<35x32xf32> + %22 = stablehlo.dot %21, %20 : (tensor<35x32xf32>, tensor<32x32xf32>) -> tensor<35x32xf32> + %23 = stablehlo.reshape %22 : (tensor<35x32xf32>) -> tensor<5x7x32xf32> + %24 = stablehlo.transpose %cst_39, dims = [1, 0] : (tensor<8x32xf32>) -> tensor<32x8xf32> + %25 = stablehlo.dot %21, %24 : (tensor<35x32xf32>, tensor<32x8xf32>) -> tensor<35x8xf32> + %26 = stablehlo.reshape %25 : (tensor<35x8xf32>) -> tensor<5x7x8xf32> + %27 = stablehlo.transpose %cst_38, dims = [1, 0] : (tensor<8x32xf32>) -> tensor<32x8xf32> + %28 = stablehlo.dot %21, %27 : (tensor<35x32xf32>, tensor<32x8xf32>) -> tensor<35x8xf32> + %29 = stablehlo.reshape %28 : (tensor<35x8xf32>) -> tensor<5x7x8xf32> + %30 = stablehlo.reshape %23 : (tensor<5x7x32xf32>) -> tensor<5x7x32x1xf32> + %31 = stablehlo.transpose %30, dims = [0, 2, 1, 3] : (tensor<5x7x32x1xf32>) -> tensor<5x32x7x1xf32> + %32 = stablehlo.reshape %26 : (tensor<5x7x8xf32>) -> tensor<5x7x8x1xf32> + %33 = stablehlo.transpose %32, dims = [0, 2, 1, 3] : (tensor<5x7x8x1xf32>) -> tensor<5x8x7x1xf32> + %34 = stablehlo.reshape %29 : (tensor<5x7x8xf32>) -> tensor<5x7x8x1xf32> + %35 = stablehlo.transpose %34, dims = [0, 2, 1, 3] : (tensor<5x7x8x1xf32>) -> tensor<5x8x7x1xf32> + %36 = stablehlo.slice %arg0 [0:7, 0:2] : (tensor<131072x2xf32>) -> tensor<7x2xf32> + %37 = stablehlo.slice %arg1 [0:7, 0:2] : (tensor<131072x2xf32>) -> tensor<7x2xf32> + %38 = stablehlo.reshape %36 : (tensor<7x2xf32>) -> tensor<1x7x2xf32> + %39 = stablehlo.reshape %38 : (tensor<1x7x2xf32>) -> tensor<1x1x7x2xf32> + %40 = stablehlo.reshape %37 : (tensor<7x2xf32>) -> tensor<1x7x2xf32> + %41 = stablehlo.reshape %40 : (tensor<1x7x2xf32>) -> tensor<1x1x7x2xf32> + %42 = stablehlo.broadcast_in_dim %31, dims = [0, 1, 2, 3] : (tensor<5x32x7x1xf32>) -> tensor<5x32x7x2xf32> + %43 = stablehlo.broadcast_in_dim %39, dims = [0, 1, 2, 3] : (tensor<1x1x7x2xf32>) -> tensor<5x32x7x2xf32> + %44 = stablehlo.multiply %42, %43 : tensor<5x32x7x2xf32> + %45 = stablehlo.negate %31 : tensor<5x32x7x1xf32> + %46 = stablehlo.broadcast_in_dim %45, dims = [0, 1, 2, 3] : (tensor<5x32x7x1xf32>) -> tensor<5x32x7x2xf32> + %47 = stablehlo.broadcast_in_dim %41, dims = [0, 1, 2, 3] : (tensor<1x1x7x2xf32>) -> tensor<5x32x7x2xf32> + %48 = stablehlo.multiply %46, %47 : tensor<5x32x7x2xf32> + %49 = stablehlo.add %44, %48 : tensor<5x32x7x2xf32> + %50 = stablehlo.broadcast_in_dim %33, dims = [0, 1, 2, 3] : (tensor<5x8x7x1xf32>) -> tensor<5x8x7x2xf32> + %51 = stablehlo.broadcast_in_dim %39, dims = [0, 1, 2, 3] : (tensor<1x1x7x2xf32>) -> tensor<5x8x7x2xf32> + %52 = stablehlo.multiply %50, %51 : tensor<5x8x7x2xf32> + %53 = stablehlo.negate %33 : tensor<5x8x7x1xf32> + %54 = stablehlo.broadcast_in_dim %53, dims = [0, 1, 2, 3] : (tensor<5x8x7x1xf32>) -> tensor<5x8x7x2xf32> + %55 = stablehlo.broadcast_in_dim %41, dims = [0, 1, 2, 3] : (tensor<1x1x7x2xf32>) -> tensor<5x8x7x2xf32> + %56 = stablehlo.multiply %54, %55 : tensor<5x8x7x2xf32> + %57 = stablehlo.add %52, %56 : tensor<5x8x7x2xf32> + %58 = stablehlo.reshape %57 : (tensor<5x8x7x2xf32>) -> tensor<5x8x1x7x2xf32> + %59 = stablehlo.broadcast_in_dim %58, dims = [0, 1, 2, 3, 4] : (tensor<5x8x1x7x2xf32>) -> tensor<5x8x4x7x2xf32> + %60 = stablehlo.reshape %59 : (tensor<5x8x4x7x2xf32>) -> tensor<5x32x7x2xf32> + %61 = stablehlo.reshape %35 : (tensor<5x8x7x1xf32>) -> tensor<5x8x1x7x1xf32> + %62 = stablehlo.broadcast_in_dim %61, dims = [0, 1, 2, 3, 4] : (tensor<5x8x1x7x1xf32>) -> tensor<5x8x4x7x1xf32> 
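+    // GQA head expansion (repeat_kv): the rotated key states (%57) and the value states (%35) are tiled from 8 KV heads to the 32 attention heads (group size 4) before the scaled dot-product attention dot_general below.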
+ %63 = stablehlo.reshape %62 : (tensor<5x8x4x7x1xf32>) -> tensor<5x32x7x1xf32> + %64 = stablehlo.transpose %60, dims = [0, 1, 3, 2] : (tensor<5x32x7x2xf32>) -> tensor<5x32x2x7xf32> + %65 = stablehlo.reshape %49 : (tensor<5x32x7x2xf32>) -> tensor<160x7x2xf32> + %66 = stablehlo.reshape %64 : (tensor<5x32x2x7xf32>) -> tensor<160x2x7xf32> + %67 = stablehlo.broadcast_in_dim %66, dims = [0, 1, 2] : (tensor<160x2x7xf32>) -> tensor<160x2x7xf32> + %68 = stablehlo.dot_general %65, %67, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<160x7x2xf32>, tensor<160x2x7xf32>) -> tensor<160x7x7xf32> + %69 = stablehlo.reshape %68 : (tensor<160x7x7xf32>) -> tensor<5x32x7x7xf32> + %70 = stablehlo.reshape %cst_3 : (tensor<1xf32>) -> tensor + %71 = stablehlo.broadcast_in_dim %69, dims = [0, 1, 2, 3] : (tensor<5x32x7x7xf32>) -> tensor<5x32x7x7xf32> + %72 = stablehlo.broadcast_in_dim %70, dims = [] : (tensor) -> tensor<5x32x7x7xf32> + %73 = stablehlo.divide %71, %72 : tensor<5x32x7x7xf32> + %74 = stablehlo.custom_call @byteir.softmax(%73) {byteir_attrs = {axis = 3 : i64}} : (tensor<5x32x7x7xf32>) -> tensor<5x32x7x7xf32> + %75 = stablehlo.reshape %74 : (tensor<5x32x7x7xf32>) -> tensor<160x7x7xf32> + %76 = stablehlo.reshape %63 : (tensor<5x32x7x1xf32>) -> tensor<160x7x1xf32> + %77 = stablehlo.broadcast_in_dim %76, dims = [0, 1, 2] : (tensor<160x7x1xf32>) -> tensor<160x7x1xf32> + %78 = stablehlo.dot_general %75, %77, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<160x7x7xf32>, tensor<160x7x1xf32>) -> tensor<160x7x1xf32> + %79 = stablehlo.reshape %78 : (tensor<160x7x1xf32>) -> tensor<5x32x7x1xf32> + %80 = stablehlo.transpose %79, dims = [0, 2, 1, 3] : (tensor<5x32x7x1xf32>) -> tensor<5x7x32x1xf32> + %81 = stablehlo.reshape %80 : (tensor<5x7x32x1xf32>) -> tensor<5x7x32xf32> + %82 = stablehlo.transpose %cst_37, dims = [1, 0] : (tensor<32x32xf32>) -> tensor<32x32xf32> + %83 = stablehlo.reshape %81 : (tensor<5x7x32xf32>) -> tensor<35x32xf32> + %84 = stablehlo.dot %83, %82 : (tensor<35x32xf32>, tensor<32x32xf32>) -> tensor<35x32xf32> + %85 = stablehlo.reshape %84 : (tensor<35x32xf32>) -> tensor<5x7x32xf32> + %86 = stablehlo.add %arg2, %85 : tensor<5x7x32xf32> + %87 = stablehlo.broadcast_in_dim %86, dims = [0, 1, 2] : (tensor<5x7x32xf32>) -> tensor<5x7x32xf32> + %88 = stablehlo.power %87, %2 : tensor<5x7x32xf32> + %89 = stablehlo.reduce(%88 init: %cst_42) applies stablehlo.add across dimensions = [2] : (tensor<5x7x32xf32>, tensor) -> tensor<5x7xf32> + %90 = stablehlo.reshape %89 : (tensor<5x7xf32>) -> tensor<5x7x1xf32> + %91 = stablehlo.broadcast_in_dim %90, dims = [0, 1, 2] : (tensor<5x7x1xf32>) -> tensor<5x7x1xf32> + %92 = stablehlo.divide %91, %8 : tensor<5x7x1xf32> + %93 = stablehlo.broadcast_in_dim %92, dims = [0, 1, 2] : (tensor<5x7x1xf32>) -> tensor<5x7x1xf32> + %94 = stablehlo.add %93, %12 : tensor<5x7x1xf32> + %95 = stablehlo.rsqrt %94 : tensor<5x7x1xf32> + %96 = stablehlo.broadcast_in_dim %95, dims = [0, 1, 2] : (tensor<5x7x1xf32>) -> tensor<5x7x32xf32> + %97 = stablehlo.multiply %87, %96 : tensor<5x7x32xf32> + %98 = stablehlo.broadcast_in_dim %97, dims = [0, 1, 2] : (tensor<5x7x32xf32>) -> tensor<5x7x32xf32> + %99 = stablehlo.broadcast_in_dim %cst_36, dims = [2] : (tensor<32xf32>) -> tensor<5x7x32xf32> + %100 = stablehlo.multiply %98, %99 : tensor<5x7x32xf32> + %101 = stablehlo.reshape %100 : (tensor<5x7x32xf32>) -> tensor<35x32xf32> + %102 = stablehlo.transpose %cst_35, dims = [1, 0] : (tensor<8x32xf32>) -> tensor<32x8xf32> + %103 = stablehlo.dot %101, %102 : 
(tensor<35x32xf32>, tensor<32x8xf32>) -> tensor<35x8xf32> + %104 = stablehlo.custom_call @byteir.softmax(%103) {byteir_attrs = {axis = 1 : i64}} : (tensor<35x8xf32>) -> tensor<35x8xf32> + %105:2 = stablehlo.custom_call @byteir.top_k(%104) {byteir_attrs = {axis = [1], k = 2 : i64, sorted = true}} : (tensor<35x8xf32>) -> (tensor<35x2xf32>, tensor<35x2xi64>) + %106 = stablehlo.reduce(%105#0 init: %cst_42) applies stablehlo.add across dimensions = [1] : (tensor<35x2xf32>, tensor) -> tensor<35xf32> + %107 = stablehlo.reshape %106 : (tensor<35xf32>) -> tensor<35x1xf32> + %108 = stablehlo.broadcast_in_dim %105#0, dims = [0, 1] : (tensor<35x2xf32>) -> tensor<35x2xf32> + %109 = stablehlo.broadcast_in_dim %107, dims = [0, 1] : (tensor<35x1xf32>) -> tensor<35x2xf32> + %110 = stablehlo.divide %108, %109 : tensor<35x2xf32> + %111 = stablehlo.divide %cst_0, %cst_1 : tensor + %112 = stablehlo.ceil %111 : tensor + %113 = stablehlo.convert %112 : (tensor) -> tensor + %114 = stablehlo.reshape %113 : (tensor) -> tensor<1xi64> + %115 = stablehlo.dynamic_iota %114, dim = 0 : (tensor<1xi64>) -> tensor<8xi64> + %116 = stablehlo.broadcast_in_dim %115, dims = [0] : (tensor<8xi64>) -> tensor<8xi64> + %117 = stablehlo.multiply %116, %c_10 : tensor<8xi64> + %118 = stablehlo.broadcast_in_dim %117, dims = [0] : (tensor<8xi64>) -> tensor<8xi64> + %119 = stablehlo.add %118, %c_9 : tensor<8xi64> + %120 = stablehlo.reshape %105#1 : (tensor<35x2xi64>) -> tensor<35x2x1xi64> + %121 = stablehlo.broadcast_in_dim %120, dims = [0, 1, 2] : (tensor<35x2x1xi64>) -> tensor<35x2x8xi64> + %122 = stablehlo.broadcast_in_dim %119, dims = [2] : (tensor<8xi64>) -> tensor<35x2x8xi64> + %123 = stablehlo.compare EQ, %121, %122, SIGNED : (tensor<35x2x8xi64>, tensor<35x2x8xi64>) -> tensor<35x2x8xi1> + %124 = stablehlo.convert %123 : (tensor<35x2x8xi1>) -> tensor<35x2x8xi64> + %125 = stablehlo.transpose %124, dims = [2, 1, 0] : (tensor<35x2x8xi64>) -> tensor<8x2x35xi64> + %126 = stablehlo.slice %125 [0:1, 0:2, 0:35] : (tensor<8x2x35xi64>) -> tensor<1x2x35xi64> + %127 = stablehlo.reshape %126 : (tensor<1x2x35xi64>) -> tensor<2x35xi64> + %128 = stablehlo.custom_call @byteir.non_zero(%127) {byteir_attrs = {}} : (tensor<2x35xi64>) -> tensor + %dim = tensor.dim %128, %c0 : tensor + %129 = arith.index_cast %dim : index to i64 + %from_elements = tensor.from_elements %129, %c1_i64 : tensor<2xi64> + %130 = stablehlo.real_dynamic_slice %128, %c_43, %from_elements, %c_44 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_48 = tensor.dim %130, %c0 : tensor + %131 = arith.index_cast %dim_48 : index to i64 + %from_elements_49 = tensor.from_elements %131 : tensor<1xi64> + %132 = stablehlo.dynamic_reshape %130, %from_elements_49 : (tensor, tensor<1xi64>) -> tensor + %from_elements_50 = tensor.from_elements %129, %c2_i64 : tensor<2xi64> + %133 = stablehlo.real_dynamic_slice %128, %c_46, %from_elements_50, %c_44 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_51 = tensor.dim %133, %c0 : tensor + %134 = arith.index_cast %dim_51 : index to i64 + %from_elements_52 = tensor.from_elements %134 : tensor<1xi64> + %135 = stablehlo.dynamic_reshape %133, %from_elements_52 : (tensor, tensor<1xi64>) -> tensor + %136 = stablehlo.reshape %101 : (tensor<35x32xf32>) -> tensor<1x35x32xf32> + %137 = stablehlo.divide %cst, %cst_1 : tensor + %138 = stablehlo.ceil %137 : tensor + %139 = stablehlo.convert %138 : (tensor) -> tensor + %140 = stablehlo.reshape %139 : (tensor) -> tensor<1xi64> + %141 = stablehlo.dynamic_iota %140, dim = 0 : 
(tensor<1xi64>) -> tensor<32xi64> + %142 = stablehlo.broadcast_in_dim %141, dims = [0] : (tensor<32xi64>) -> tensor<32xi64> + %143 = stablehlo.multiply %142, %c_8 : tensor<32xi64> + %144 = stablehlo.broadcast_in_dim %143, dims = [0] : (tensor<32xi64>) -> tensor<32xi64> + %145 = stablehlo.add %144, %c_7 : tensor<32xi64> + %dim_53 = tensor.dim %135, %c0 : tensor + %146 = arith.index_cast %dim_53 : index to i64 + %from_elements_54 = tensor.from_elements %146, %c1_i64 : tensor<2xi64> + %147 = stablehlo.dynamic_reshape %135, %from_elements_54 : (tensor, tensor<2xi64>) -> tensor + %148 = stablehlo.divide %cst_1, %cst_1 : tensor + %149 = stablehlo.ceil %148 : tensor + %150 = stablehlo.convert %149 : (tensor) -> tensor + %151 = stablehlo.reshape %150 : (tensor) -> tensor<1xi64> + %152 = stablehlo.dynamic_iota %151, dim = 0 : (tensor<1xi64>) -> tensor<1xi64> + %153 = stablehlo.broadcast_in_dim %152, dims = [0] : (tensor<1xi64>) -> tensor<1xi64> + %154 = stablehlo.multiply %153, %c_6 : tensor<1xi64> + %155 = stablehlo.broadcast_in_dim %154, dims = [0] : (tensor<1xi64>) -> tensor<1xi64> + %156 = stablehlo.add %155, %c : tensor<1xi64> + %157 = stablehlo.reshape %156 : (tensor<1xi64>) -> tensor<1x1xi64> + %158 = stablehlo.reshape %157 : (tensor<1x1xi64>) -> tensor<1x1x1xi64> + %dim_55 = tensor.dim %147, %c0 : tensor + %159 = arith.index_cast %dim_55 : index to i64 + %from_elements_56 = tensor.from_elements %c1_i64, %159, %c32_i64 : tensor<3xi64> + %160 = stablehlo.dynamic_broadcast_in_dim %158, %from_elements_56, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x32xi64> + %dim_57 = tensor.dim %160, %c1 : tensor<1x?x32xi64> + %161 = arith.index_cast %dim_57 : index to i64 + %from_elements_58 = tensor.from_elements %c1_i64, %161, %c32_i64, %c1_i64 : tensor<4xi64> + %162 = stablehlo.dynamic_reshape %160, %from_elements_58 : (tensor<1x?x32xi64>, tensor<4xi64>) -> tensor<1x?x32x1xi64> + %163 = stablehlo.dynamic_broadcast_in_dim %147, %from_elements_56, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x32xi64> + %dim_59 = tensor.dim %163, %c1 : tensor<1x?x32xi64> + %164 = arith.index_cast %dim_59 : index to i64 + %from_elements_60 = tensor.from_elements %c1_i64, %164, %c32_i64, %c1_i64 : tensor<4xi64> + %165 = stablehlo.dynamic_reshape %163, %from_elements_60 : (tensor<1x?x32xi64>, tensor<4xi64>) -> tensor<1x?x32x1xi64> + %166 = stablehlo.dynamic_broadcast_in_dim %145, %from_elements_56, dims = [2] : (tensor<32xi64>, tensor<3xi64>) -> tensor<1x?x32xi64> + %dim_61 = tensor.dim %166, %c1 : tensor<1x?x32xi64> + %167 = arith.index_cast %dim_61 : index to i64 + %from_elements_62 = tensor.from_elements %c1_i64, %167, %c32_i64, %c1_i64 : tensor<4xi64> + %168 = stablehlo.dynamic_reshape %166, %from_elements_62 : (tensor<1x?x32xi64>, tensor<4xi64>) -> tensor<1x?x32x1xi64> + %169 = stablehlo.concatenate %162, %165, %168, dim = 3 : (tensor<1x?x32x1xi64>, tensor<1x?x32x1xi64>, tensor<1x?x32x1xi64>) -> tensor<1x?x32x3xi64> + %170 = "stablehlo.gather"(%136, %169) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x35x32xf32>, tensor<1x?x32x3xi64>) -> tensor<1x?x32xf32> + %171 = shape.shape_of %170 : tensor<1x?x32xf32> -> tensor<3xindex> + %172 = shape.num_elements %171 : tensor<3xindex> -> index + %173 = stablehlo.compute_reshape_shape %172, %c_45 : (index, tensor<2xi64>) -> tensor<2xi64> + %174 = stablehlo.dynamic_reshape %170, %173 : (tensor<1x?x32xf32>, tensor<2xi64>) -> tensor + %175 = stablehlo.transpose %cst_34, dims = [1, 0] : 
(tensor<14336x32xf32>) -> tensor<32x14336xf32> + %176 = stablehlo.dot %174, %175 : (tensor, tensor<32x14336xf32>) -> tensor + %177 = stablehlo.logistic %176 : tensor + %178 = shape.shape_of %177 : tensor -> tensor<2xindex> + %179 = shape.shape_of %176 : tensor -> tensor<2xindex> + %180 = shape.cstr_broadcastable %178, %179 : tensor<2xindex>, tensor<2xindex> + %181 = shape.assuming %180 -> (tensor) { + %695 = shape.broadcast %178, %179 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %696 = stablehlo.dynamic_broadcast_in_dim %177, %695, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %697 = stablehlo.dynamic_broadcast_in_dim %176, %695, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %698 = stablehlo.multiply %696, %697 : tensor + shape.assuming_yield %698 : tensor + } + %182 = stablehlo.transpose %cst_33, dims = [1, 0] : (tensor<14336x32xf32>) -> tensor<32x14336xf32> + %183 = stablehlo.dot %174, %182 : (tensor, tensor<32x14336xf32>) -> tensor + %184 = shape.shape_of %181 : tensor -> tensor<2xindex> + %185 = shape.shape_of %183 : tensor -> tensor<2xindex> + %186 = shape.cstr_broadcastable %184, %185 : tensor<2xindex>, tensor<2xindex> + %187 = shape.assuming %186 -> (tensor) { + %695 = shape.broadcast %184, %185 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %696 = stablehlo.dynamic_broadcast_in_dim %181, %695, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %697 = stablehlo.dynamic_broadcast_in_dim %183, %695, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %698 = stablehlo.multiply %696, %697 : tensor + shape.assuming_yield %698 : tensor + } + %188 = stablehlo.transpose %cst_32, dims = [1, 0] : (tensor<32x14336xf32>) -> tensor<14336x32xf32> + %189 = stablehlo.dot %187, %188 : (tensor, tensor<14336x32xf32>) -> tensor + %190 = stablehlo.reshape %110 : (tensor<35x2xf32>) -> tensor<35x2x1xf32> + %dim_63 = tensor.dim %135, %c0 : tensor + %191 = arith.index_cast %dim_63 : index to i64 + %from_elements_64 = tensor.from_elements %191, %c1_i64 : tensor<2xi64> + %192 = stablehlo.dynamic_reshape %135, %from_elements_64 : (tensor, tensor<2xi64>) -> tensor + %dim_65 = tensor.dim %132, %c0 : tensor + %193 = arith.index_cast %dim_65 : index to i64 + %from_elements_66 = tensor.from_elements %193, %c1_i64 : tensor<2xi64> + %194 = stablehlo.dynamic_reshape %132, %from_elements_66 : (tensor, tensor<2xi64>) -> tensor + %195 = stablehlo.concatenate %192, %194, dim = 1 : (tensor, tensor) -> tensor + %196 = "stablehlo.gather"(%190, %195) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<35x2x1xf32>, tensor) -> tensor + %197 = shape.shape_of %189 : tensor -> tensor<2xindex> + %198 = shape.shape_of %196 : tensor -> tensor<2xindex> + %199 = shape.cstr_broadcastable %197, %198 : tensor<2xindex>, tensor<2xindex> + %200 = shape.assuming %199 -> (tensor) { + %695 = shape.broadcast %197, %198 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %696 = stablehlo.dynamic_broadcast_in_dim %189, %695, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %697 = stablehlo.dynamic_broadcast_in_dim %196, %695, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %698 = stablehlo.multiply %696, %697 : tensor + shape.assuming_yield %698 : tensor + } + %201 = stablehlo.reshape %cst_3 : (tensor<1xf32>) -> tensor + %202 = shape.shape_of %200 : tensor -> tensor<2xindex> + %203 = stablehlo.dynamic_broadcast_in_dim %200, %202, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %204 = stablehlo.dynamic_broadcast_in_dim %201, %202, dims = 
[] : (tensor, tensor<2xindex>) -> tensor + %205 = stablehlo.multiply %203, %204 : tensor + %dim_67 = tensor.dim %147, %c0 : tensor + %206 = arith.index_cast %dim_67 : index to i64 + %dim_68 = tensor.dim %200, %c0 : tensor + %207 = arith.index_cast %dim_68 : index to i64 + %208 = arith.maxsi %206, %207 : i64 + %209 = arith.index_cast %208 : i64 to index + %from_elements_69 = tensor.from_elements %209, %c32 : tensor<2xindex> + %210 = stablehlo.dynamic_broadcast_in_dim %147, %from_elements_69, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_70 = tensor.dim %210, %c0 : tensor + %211 = arith.index_cast %dim_70 : index to i64 + %from_elements_71 = tensor.from_elements %211, %c32_i64 : tensor<2xi64> + %212 = stablehlo.real_dynamic_slice %205, %c_43, %from_elements_71, %c_44 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_72 = tensor.from_elements %211, %c32_i64, %c1_i64 : tensor<3xi64> + %213 = stablehlo.dynamic_reshape %210, %from_elements_72 : (tensor, tensor<3xi64>) -> tensor + %214 = stablehlo.dynamic_iota %from_elements_72, dim = 1 : (tensor<3xi64>) -> tensor + %215 = stablehlo.concatenate %213, %214, dim = 2 : (tensor, tensor) -> tensor + %216 = "stablehlo.scatter"(%cst_2, %215, %212) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %695 = stablehlo.add %arg3, %arg4 : tensor + stablehlo.return %695 : tensor + }) : (tensor<35x32xf32>, tensor, tensor) -> tensor<35x32xf32> + %217 = stablehlo.slice %125 [1:2, 0:2, 0:35] : (tensor<8x2x35xi64>) -> tensor<1x2x35xi64> + %218 = stablehlo.reshape %217 : (tensor<1x2x35xi64>) -> tensor<2x35xi64> + %219 = stablehlo.custom_call @byteir.non_zero(%218) {byteir_attrs = {}} : (tensor<2x35xi64>) -> tensor + %dim_73 = tensor.dim %219, %c0 : tensor + %220 = arith.index_cast %dim_73 : index to i64 + %from_elements_74 = tensor.from_elements %220, %c1_i64 : tensor<2xi64> + %221 = stablehlo.real_dynamic_slice %219, %c_43, %from_elements_74, %c_44 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_75 = tensor.dim %221, %c0 : tensor + %222 = arith.index_cast %dim_75 : index to i64 + %from_elements_76 = tensor.from_elements %222 : tensor<1xi64> + %223 = stablehlo.dynamic_reshape %221, %from_elements_76 : (tensor, tensor<1xi64>) -> tensor + %from_elements_77 = tensor.from_elements %220, %c2_i64 : tensor<2xi64> + %224 = stablehlo.real_dynamic_slice %219, %c_46, %from_elements_77, %c_44 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_78 = tensor.dim %224, %c0 : tensor + %225 = arith.index_cast %dim_78 : index to i64 + %from_elements_79 = tensor.from_elements %225 : tensor<1xi64> + %226 = stablehlo.dynamic_reshape %224, %from_elements_79 : (tensor, tensor<1xi64>) -> tensor + %dim_80 = tensor.dim %226, %c0 : tensor + %227 = arith.index_cast %dim_80 : index to i64 + %from_elements_81 = tensor.from_elements %227, %c1_i64 : tensor<2xi64> + %228 = stablehlo.dynamic_reshape %226, %from_elements_81 : (tensor, tensor<2xi64>) -> tensor + %dim_82 = tensor.dim %228, %c0 : tensor + %229 = arith.index_cast %dim_82 : index to i64 + %from_elements_83 = tensor.from_elements %c1_i64, %229, %c32_i64 : tensor<3xi64> + %230 = stablehlo.dynamic_broadcast_in_dim %158, %from_elements_83, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x32xi64> + %dim_84 = tensor.dim %230, %c1 : tensor<1x?x32xi64> + %231 = arith.index_cast %dim_84 : index to i64 + %from_elements_85 = tensor.from_elements 
%c1_i64, %231, %c32_i64, %c1_i64 : tensor<4xi64> + %232 = stablehlo.dynamic_reshape %230, %from_elements_85 : (tensor<1x?x32xi64>, tensor<4xi64>) -> tensor<1x?x32x1xi64> + %233 = stablehlo.dynamic_broadcast_in_dim %228, %from_elements_83, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x32xi64> + %dim_86 = tensor.dim %233, %c1 : tensor<1x?x32xi64> + %234 = arith.index_cast %dim_86 : index to i64 + %from_elements_87 = tensor.from_elements %c1_i64, %234, %c32_i64, %c1_i64 : tensor<4xi64> + %235 = stablehlo.dynamic_reshape %233, %from_elements_87 : (tensor<1x?x32xi64>, tensor<4xi64>) -> tensor<1x?x32x1xi64> + %236 = stablehlo.dynamic_broadcast_in_dim %145, %from_elements_83, dims = [2] : (tensor<32xi64>, tensor<3xi64>) -> tensor<1x?x32xi64> + %dim_88 = tensor.dim %236, %c1 : tensor<1x?x32xi64> + %237 = arith.index_cast %dim_88 : index to i64 + %from_elements_89 = tensor.from_elements %c1_i64, %237, %c32_i64, %c1_i64 : tensor<4xi64> + %238 = stablehlo.dynamic_reshape %236, %from_elements_89 : (tensor<1x?x32xi64>, tensor<4xi64>) -> tensor<1x?x32x1xi64> + %239 = stablehlo.concatenate %232, %235, %238, dim = 3 : (tensor<1x?x32x1xi64>, tensor<1x?x32x1xi64>, tensor<1x?x32x1xi64>) -> tensor<1x?x32x3xi64> + %240 = "stablehlo.gather"(%136, %239) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x35x32xf32>, tensor<1x?x32x3xi64>) -> tensor<1x?x32xf32> + %241 = shape.shape_of %240 : tensor<1x?x32xf32> -> tensor<3xindex> + %242 = shape.num_elements %241 : tensor<3xindex> -> index + %243 = stablehlo.compute_reshape_shape %242, %c_45 : (index, tensor<2xi64>) -> tensor<2xi64> + %244 = stablehlo.dynamic_reshape %240, %243 : (tensor<1x?x32xf32>, tensor<2xi64>) -> tensor + %245 = stablehlo.transpose %cst_31, dims = [1, 0] : (tensor<14336x32xf32>) -> tensor<32x14336xf32> + %246 = stablehlo.dot %244, %245 : (tensor, tensor<32x14336xf32>) -> tensor + %247 = stablehlo.logistic %246 : tensor + %248 = shape.shape_of %247 : tensor -> tensor<2xindex> + %249 = shape.shape_of %246 : tensor -> tensor<2xindex> + %250 = shape.cstr_broadcastable %248, %249 : tensor<2xindex>, tensor<2xindex> + %251 = shape.assuming %250 -> (tensor) { + %695 = shape.broadcast %248, %249 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %696 = stablehlo.dynamic_broadcast_in_dim %247, %695, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %697 = stablehlo.dynamic_broadcast_in_dim %246, %695, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %698 = stablehlo.multiply %696, %697 : tensor + shape.assuming_yield %698 : tensor + } + %252 = stablehlo.transpose %cst_30, dims = [1, 0] : (tensor<14336x32xf32>) -> tensor<32x14336xf32> + %253 = stablehlo.dot %244, %252 : (tensor, tensor<32x14336xf32>) -> tensor + %254 = shape.shape_of %251 : tensor -> tensor<2xindex> + %255 = shape.shape_of %253 : tensor -> tensor<2xindex> + %256 = shape.cstr_broadcastable %254, %255 : tensor<2xindex>, tensor<2xindex> + %257 = shape.assuming %256 -> (tensor) { + %695 = shape.broadcast %254, %255 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %696 = stablehlo.dynamic_broadcast_in_dim %251, %695, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %697 = stablehlo.dynamic_broadcast_in_dim %253, %695, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %698 = stablehlo.multiply %696, %697 : tensor + shape.assuming_yield %698 : tensor + } + %258 = stablehlo.transpose %cst_29, dims = [1, 0] : (tensor<32x14336xf32>) -> tensor<14336x32xf32> + %259 = stablehlo.dot %257, %258 : (tensor, 
tensor<14336x32xf32>) -> tensor + %dim_90 = tensor.dim %226, %c0 : tensor + %260 = arith.index_cast %dim_90 : index to i64 + %from_elements_91 = tensor.from_elements %260, %c1_i64 : tensor<2xi64> + %261 = stablehlo.dynamic_reshape %226, %from_elements_91 : (tensor, tensor<2xi64>) -> tensor + %dim_92 = tensor.dim %223, %c0 : tensor + %262 = arith.index_cast %dim_92 : index to i64 + %from_elements_93 = tensor.from_elements %262, %c1_i64 : tensor<2xi64> + %263 = stablehlo.dynamic_reshape %223, %from_elements_93 : (tensor, tensor<2xi64>) -> tensor + %264 = stablehlo.concatenate %261, %263, dim = 1 : (tensor, tensor) -> tensor + %265 = "stablehlo.gather"(%190, %264) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<35x2x1xf32>, tensor) -> tensor + %266 = shape.shape_of %259 : tensor -> tensor<2xindex> + %267 = shape.shape_of %265 : tensor -> tensor<2xindex> + %268 = shape.cstr_broadcastable %266, %267 : tensor<2xindex>, tensor<2xindex> + %269 = shape.assuming %268 -> (tensor) { + %695 = shape.broadcast %266, %267 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %696 = stablehlo.dynamic_broadcast_in_dim %259, %695, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %697 = stablehlo.dynamic_broadcast_in_dim %265, %695, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %698 = stablehlo.multiply %696, %697 : tensor + shape.assuming_yield %698 : tensor + } + %270 = shape.shape_of %269 : tensor -> tensor<2xindex> + %271 = stablehlo.dynamic_broadcast_in_dim %269, %270, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %272 = stablehlo.dynamic_broadcast_in_dim %201, %270, dims = [] : (tensor, tensor<2xindex>) -> tensor + %273 = stablehlo.multiply %271, %272 : tensor + %dim_94 = tensor.dim %228, %c0 : tensor + %274 = arith.index_cast %dim_94 : index to i64 + %dim_95 = tensor.dim %269, %c0 : tensor + %275 = arith.index_cast %dim_95 : index to i64 + %276 = arith.maxsi %274, %275 : i64 + %277 = arith.index_cast %276 : i64 to index + %from_elements_96 = tensor.from_elements %277, %c32 : tensor<2xindex> + %278 = stablehlo.dynamic_broadcast_in_dim %228, %from_elements_96, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_97 = tensor.dim %278, %c0 : tensor + %279 = arith.index_cast %dim_97 : index to i64 + %from_elements_98 = tensor.from_elements %279, %c32_i64 : tensor<2xi64> + %280 = stablehlo.real_dynamic_slice %273, %c_43, %from_elements_98, %c_44 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_99 = tensor.from_elements %279, %c32_i64, %c1_i64 : tensor<3xi64> + %281 = stablehlo.dynamic_reshape %278, %from_elements_99 : (tensor, tensor<3xi64>) -> tensor + %282 = stablehlo.dynamic_iota %from_elements_99, dim = 1 : (tensor<3xi64>) -> tensor + %283 = stablehlo.concatenate %281, %282, dim = 2 : (tensor, tensor) -> tensor + %284 = "stablehlo.scatter"(%216, %283, %280) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %695 = stablehlo.add %arg3, %arg4 : tensor + stablehlo.return %695 : tensor + }) : (tensor<35x32xf32>, tensor, tensor) -> tensor<35x32xf32> + %285 = stablehlo.slice %125 [2:3, 0:2, 0:35] : (tensor<8x2x35xi64>) -> tensor<1x2x35xi64> + %286 = stablehlo.reshape %285 : (tensor<1x2x35xi64>) -> tensor<2x35xi64> + %287 = stablehlo.custom_call @byteir.non_zero(%286) {byteir_attrs = {}} : (tensor<2x35xi64>) -> tensor + %dim_100 = tensor.dim %287, %c0 : tensor + %288 = arith.index_cast %dim_100 : index 
to i64 + %from_elements_101 = tensor.from_elements %288, %c1_i64 : tensor<2xi64> + %289 = stablehlo.real_dynamic_slice %287, %c_43, %from_elements_101, %c_44 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_102 = tensor.dim %289, %c0 : tensor + %290 = arith.index_cast %dim_102 : index to i64 + %from_elements_103 = tensor.from_elements %290 : tensor<1xi64> + %291 = stablehlo.dynamic_reshape %289, %from_elements_103 : (tensor, tensor<1xi64>) -> tensor + %from_elements_104 = tensor.from_elements %288, %c2_i64 : tensor<2xi64> + %292 = stablehlo.real_dynamic_slice %287, %c_46, %from_elements_104, %c_44 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_105 = tensor.dim %292, %c0 : tensor + %293 = arith.index_cast %dim_105 : index to i64 + %from_elements_106 = tensor.from_elements %293 : tensor<1xi64> + %294 = stablehlo.dynamic_reshape %292, %from_elements_106 : (tensor, tensor<1xi64>) -> tensor + %dim_107 = tensor.dim %294, %c0 : tensor + %295 = arith.index_cast %dim_107 : index to i64 + %from_elements_108 = tensor.from_elements %295, %c1_i64 : tensor<2xi64> + %296 = stablehlo.dynamic_reshape %294, %from_elements_108 : (tensor, tensor<2xi64>) -> tensor + %dim_109 = tensor.dim %296, %c0 : tensor + %297 = arith.index_cast %dim_109 : index to i64 + %from_elements_110 = tensor.from_elements %c1_i64, %297, %c32_i64 : tensor<3xi64> + %298 = stablehlo.dynamic_broadcast_in_dim %158, %from_elements_110, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x32xi64> + %dim_111 = tensor.dim %298, %c1 : tensor<1x?x32xi64> + %299 = arith.index_cast %dim_111 : index to i64 + %from_elements_112 = tensor.from_elements %c1_i64, %299, %c32_i64, %c1_i64 : tensor<4xi64> + %300 = stablehlo.dynamic_reshape %298, %from_elements_112 : (tensor<1x?x32xi64>, tensor<4xi64>) -> tensor<1x?x32x1xi64> + %301 = stablehlo.dynamic_broadcast_in_dim %296, %from_elements_110, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x32xi64> + %dim_113 = tensor.dim %301, %c1 : tensor<1x?x32xi64> + %302 = arith.index_cast %dim_113 : index to i64 + %from_elements_114 = tensor.from_elements %c1_i64, %302, %c32_i64, %c1_i64 : tensor<4xi64> + %303 = stablehlo.dynamic_reshape %301, %from_elements_114 : (tensor<1x?x32xi64>, tensor<4xi64>) -> tensor<1x?x32x1xi64> + %304 = stablehlo.dynamic_broadcast_in_dim %145, %from_elements_110, dims = [2] : (tensor<32xi64>, tensor<3xi64>) -> tensor<1x?x32xi64> + %dim_115 = tensor.dim %304, %c1 : tensor<1x?x32xi64> + %305 = arith.index_cast %dim_115 : index to i64 + %from_elements_116 = tensor.from_elements %c1_i64, %305, %c32_i64, %c1_i64 : tensor<4xi64> + %306 = stablehlo.dynamic_reshape %304, %from_elements_116 : (tensor<1x?x32xi64>, tensor<4xi64>) -> tensor<1x?x32x1xi64> + %307 = stablehlo.concatenate %300, %303, %306, dim = 3 : (tensor<1x?x32x1xi64>, tensor<1x?x32x1xi64>, tensor<1x?x32x1xi64>) -> tensor<1x?x32x3xi64> + %308 = "stablehlo.gather"(%136, %307) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x35x32xf32>, tensor<1x?x32x3xi64>) -> tensor<1x?x32xf32> + %309 = shape.shape_of %308 : tensor<1x?x32xf32> -> tensor<3xindex> + %310 = shape.num_elements %309 : tensor<3xindex> -> index + %311 = stablehlo.compute_reshape_shape %310, %c_45 : (index, tensor<2xi64>) -> tensor<2xi64> + %312 = stablehlo.dynamic_reshape %308, %311 : (tensor<1x?x32xf32>, tensor<2xi64>) -> tensor + %313 = stablehlo.transpose %cst_28, dims = [1, 0] : (tensor<14336x32xf32>) -> tensor<32x14336xf32> + %314 = stablehlo.dot 
%312, %313 : (tensor, tensor<32x14336xf32>) -> tensor + %315 = stablehlo.logistic %314 : tensor + %316 = shape.shape_of %315 : tensor -> tensor<2xindex> + %317 = shape.shape_of %314 : tensor -> tensor<2xindex> + %318 = shape.cstr_broadcastable %316, %317 : tensor<2xindex>, tensor<2xindex> + %319 = shape.assuming %318 -> (tensor) { + %695 = shape.broadcast %316, %317 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %696 = stablehlo.dynamic_broadcast_in_dim %315, %695, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %697 = stablehlo.dynamic_broadcast_in_dim %314, %695, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %698 = stablehlo.multiply %696, %697 : tensor + shape.assuming_yield %698 : tensor + } + %320 = stablehlo.transpose %cst_27, dims = [1, 0] : (tensor<14336x32xf32>) -> tensor<32x14336xf32> + %321 = stablehlo.dot %312, %320 : (tensor, tensor<32x14336xf32>) -> tensor + %322 = shape.shape_of %319 : tensor -> tensor<2xindex> + %323 = shape.shape_of %321 : tensor -> tensor<2xindex> + %324 = shape.cstr_broadcastable %322, %323 : tensor<2xindex>, tensor<2xindex> + %325 = shape.assuming %324 -> (tensor) { + %695 = shape.broadcast %322, %323 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %696 = stablehlo.dynamic_broadcast_in_dim %319, %695, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %697 = stablehlo.dynamic_broadcast_in_dim %321, %695, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %698 = stablehlo.multiply %696, %697 : tensor + shape.assuming_yield %698 : tensor + } + %326 = stablehlo.transpose %cst_26, dims = [1, 0] : (tensor<32x14336xf32>) -> tensor<14336x32xf32> + %327 = stablehlo.dot %325, %326 : (tensor, tensor<14336x32xf32>) -> tensor + %dim_117 = tensor.dim %294, %c0 : tensor + %328 = arith.index_cast %dim_117 : index to i64 + %from_elements_118 = tensor.from_elements %328, %c1_i64 : tensor<2xi64> + %329 = stablehlo.dynamic_reshape %294, %from_elements_118 : (tensor, tensor<2xi64>) -> tensor + %dim_119 = tensor.dim %291, %c0 : tensor + %330 = arith.index_cast %dim_119 : index to i64 + %from_elements_120 = tensor.from_elements %330, %c1_i64 : tensor<2xi64> + %331 = stablehlo.dynamic_reshape %291, %from_elements_120 : (tensor, tensor<2xi64>) -> tensor + %332 = stablehlo.concatenate %329, %331, dim = 1 : (tensor, tensor) -> tensor + %333 = "stablehlo.gather"(%190, %332) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<35x2x1xf32>, tensor) -> tensor + %334 = shape.shape_of %327 : tensor -> tensor<2xindex> + %335 = shape.shape_of %333 : tensor -> tensor<2xindex> + %336 = shape.cstr_broadcastable %334, %335 : tensor<2xindex>, tensor<2xindex> + %337 = shape.assuming %336 -> (tensor) { + %695 = shape.broadcast %334, %335 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %696 = stablehlo.dynamic_broadcast_in_dim %327, %695, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %697 = stablehlo.dynamic_broadcast_in_dim %333, %695, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %698 = stablehlo.multiply %696, %697 : tensor + shape.assuming_yield %698 : tensor + } + %338 = shape.shape_of %337 : tensor -> tensor<2xindex> + %339 = stablehlo.dynamic_broadcast_in_dim %337, %338, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %340 = stablehlo.dynamic_broadcast_in_dim %201, %338, dims = [] : (tensor, tensor<2xindex>) -> tensor + %341 = stablehlo.multiply %339, %340 : tensor + %dim_121 = tensor.dim %296, %c0 : tensor + %342 = arith.index_cast %dim_121 : index to i64 + %dim_122 = 
tensor.dim %337, %c0 : tensor + %343 = arith.index_cast %dim_122 : index to i64 + %344 = arith.maxsi %342, %343 : i64 + %345 = arith.index_cast %344 : i64 to index + %from_elements_123 = tensor.from_elements %345, %c32 : tensor<2xindex> + %346 = stablehlo.dynamic_broadcast_in_dim %296, %from_elements_123, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_124 = tensor.dim %346, %c0 : tensor + %347 = arith.index_cast %dim_124 : index to i64 + %from_elements_125 = tensor.from_elements %347, %c32_i64 : tensor<2xi64> + %348 = stablehlo.real_dynamic_slice %341, %c_43, %from_elements_125, %c_44 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_126 = tensor.from_elements %347, %c32_i64, %c1_i64 : tensor<3xi64> + %349 = stablehlo.dynamic_reshape %346, %from_elements_126 : (tensor, tensor<3xi64>) -> tensor + %350 = stablehlo.dynamic_iota %from_elements_126, dim = 1 : (tensor<3xi64>) -> tensor + %351 = stablehlo.concatenate %349, %350, dim = 2 : (tensor, tensor) -> tensor + %352 = "stablehlo.scatter"(%284, %351, %348) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %695 = stablehlo.add %arg3, %arg4 : tensor + stablehlo.return %695 : tensor + }) : (tensor<35x32xf32>, tensor, tensor) -> tensor<35x32xf32> + %353 = stablehlo.slice %125 [3:4, 0:2, 0:35] : (tensor<8x2x35xi64>) -> tensor<1x2x35xi64> + %354 = stablehlo.reshape %353 : (tensor<1x2x35xi64>) -> tensor<2x35xi64> + %355 = stablehlo.custom_call @byteir.non_zero(%354) {byteir_attrs = {}} : (tensor<2x35xi64>) -> tensor + %dim_127 = tensor.dim %355, %c0 : tensor + %356 = arith.index_cast %dim_127 : index to i64 + %from_elements_128 = tensor.from_elements %356, %c1_i64 : tensor<2xi64> + %357 = stablehlo.real_dynamic_slice %355, %c_43, %from_elements_128, %c_44 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_129 = tensor.dim %357, %c0 : tensor + %358 = arith.index_cast %dim_129 : index to i64 + %from_elements_130 = tensor.from_elements %358 : tensor<1xi64> + %359 = stablehlo.dynamic_reshape %357, %from_elements_130 : (tensor, tensor<1xi64>) -> tensor + %from_elements_131 = tensor.from_elements %356, %c2_i64 : tensor<2xi64> + %360 = stablehlo.real_dynamic_slice %355, %c_46, %from_elements_131, %c_44 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_132 = tensor.dim %360, %c0 : tensor + %361 = arith.index_cast %dim_132 : index to i64 + %from_elements_133 = tensor.from_elements %361 : tensor<1xi64> + %362 = stablehlo.dynamic_reshape %360, %from_elements_133 : (tensor, tensor<1xi64>) -> tensor + %dim_134 = tensor.dim %362, %c0 : tensor + %363 = arith.index_cast %dim_134 : index to i64 + %from_elements_135 = tensor.from_elements %363, %c1_i64 : tensor<2xi64> + %364 = stablehlo.dynamic_reshape %362, %from_elements_135 : (tensor, tensor<2xi64>) -> tensor + %dim_136 = tensor.dim %364, %c0 : tensor + %365 = arith.index_cast %dim_136 : index to i64 + %from_elements_137 = tensor.from_elements %c1_i64, %365, %c32_i64 : tensor<3xi64> + %366 = stablehlo.dynamic_broadcast_in_dim %158, %from_elements_137, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x32xi64> + %dim_138 = tensor.dim %366, %c1 : tensor<1x?x32xi64> + %367 = arith.index_cast %dim_138 : index to i64 + %from_elements_139 = tensor.from_elements %c1_i64, %367, %c32_i64, %c1_i64 : tensor<4xi64> + %368 = stablehlo.dynamic_reshape %366, %from_elements_139 : (tensor<1x?x32xi64>, tensor<4xi64>) -> 
tensor<1x?x32x1xi64> + %369 = stablehlo.dynamic_broadcast_in_dim %364, %from_elements_137, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x32xi64> + %dim_140 = tensor.dim %369, %c1 : tensor<1x?x32xi64> + %370 = arith.index_cast %dim_140 : index to i64 + %from_elements_141 = tensor.from_elements %c1_i64, %370, %c32_i64, %c1_i64 : tensor<4xi64> + %371 = stablehlo.dynamic_reshape %369, %from_elements_141 : (tensor<1x?x32xi64>, tensor<4xi64>) -> tensor<1x?x32x1xi64> + %372 = stablehlo.dynamic_broadcast_in_dim %145, %from_elements_137, dims = [2] : (tensor<32xi64>, tensor<3xi64>) -> tensor<1x?x32xi64> + %dim_142 = tensor.dim %372, %c1 : tensor<1x?x32xi64> + %373 = arith.index_cast %dim_142 : index to i64 + %from_elements_143 = tensor.from_elements %c1_i64, %373, %c32_i64, %c1_i64 : tensor<4xi64> + %374 = stablehlo.dynamic_reshape %372, %from_elements_143 : (tensor<1x?x32xi64>, tensor<4xi64>) -> tensor<1x?x32x1xi64> + %375 = stablehlo.concatenate %368, %371, %374, dim = 3 : (tensor<1x?x32x1xi64>, tensor<1x?x32x1xi64>, tensor<1x?x32x1xi64>) -> tensor<1x?x32x3xi64> + %376 = "stablehlo.gather"(%136, %375) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x35x32xf32>, tensor<1x?x32x3xi64>) -> tensor<1x?x32xf32> + %377 = shape.shape_of %376 : tensor<1x?x32xf32> -> tensor<3xindex> + %378 = shape.num_elements %377 : tensor<3xindex> -> index + %379 = stablehlo.compute_reshape_shape %378, %c_45 : (index, tensor<2xi64>) -> tensor<2xi64> + %380 = stablehlo.dynamic_reshape %376, %379 : (tensor<1x?x32xf32>, tensor<2xi64>) -> tensor + %381 = stablehlo.transpose %cst_25, dims = [1, 0] : (tensor<14336x32xf32>) -> tensor<32x14336xf32> + %382 = stablehlo.dot %380, %381 : (tensor, tensor<32x14336xf32>) -> tensor + %383 = stablehlo.logistic %382 : tensor + %384 = shape.shape_of %383 : tensor -> tensor<2xindex> + %385 = shape.shape_of %382 : tensor -> tensor<2xindex> + %386 = shape.cstr_broadcastable %384, %385 : tensor<2xindex>, tensor<2xindex> + %387 = shape.assuming %386 -> (tensor) { + %695 = shape.broadcast %384, %385 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %696 = stablehlo.dynamic_broadcast_in_dim %383, %695, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %697 = stablehlo.dynamic_broadcast_in_dim %382, %695, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %698 = stablehlo.multiply %696, %697 : tensor + shape.assuming_yield %698 : tensor + } + %388 = stablehlo.transpose %cst_24, dims = [1, 0] : (tensor<14336x32xf32>) -> tensor<32x14336xf32> + %389 = stablehlo.dot %380, %388 : (tensor, tensor<32x14336xf32>) -> tensor + %390 = shape.shape_of %387 : tensor -> tensor<2xindex> + %391 = shape.shape_of %389 : tensor -> tensor<2xindex> + %392 = shape.cstr_broadcastable %390, %391 : tensor<2xindex>, tensor<2xindex> + %393 = shape.assuming %392 -> (tensor) { + %695 = shape.broadcast %390, %391 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %696 = stablehlo.dynamic_broadcast_in_dim %387, %695, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %697 = stablehlo.dynamic_broadcast_in_dim %389, %695, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %698 = stablehlo.multiply %696, %697 : tensor + shape.assuming_yield %698 : tensor + } + %394 = stablehlo.transpose %cst_23, dims = [1, 0] : (tensor<32x14336xf32>) -> tensor<14336x32xf32> + %395 = stablehlo.dot %393, %394 : (tensor, tensor<14336x32xf32>) -> tensor + %dim_144 = tensor.dim %362, %c0 : tensor + %396 = arith.index_cast %dim_144 : index to i64 + %from_elements_145 = 
tensor.from_elements %396, %c1_i64 : tensor<2xi64> + %397 = stablehlo.dynamic_reshape %362, %from_elements_145 : (tensor, tensor<2xi64>) -> tensor + %dim_146 = tensor.dim %359, %c0 : tensor + %398 = arith.index_cast %dim_146 : index to i64 + %from_elements_147 = tensor.from_elements %398, %c1_i64 : tensor<2xi64> + %399 = stablehlo.dynamic_reshape %359, %from_elements_147 : (tensor, tensor<2xi64>) -> tensor + %400 = stablehlo.concatenate %397, %399, dim = 1 : (tensor, tensor) -> tensor + %401 = "stablehlo.gather"(%190, %400) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<35x2x1xf32>, tensor) -> tensor + %402 = shape.shape_of %395 : tensor -> tensor<2xindex> + %403 = shape.shape_of %401 : tensor -> tensor<2xindex> + %404 = shape.cstr_broadcastable %402, %403 : tensor<2xindex>, tensor<2xindex> + %405 = shape.assuming %404 -> (tensor) { + %695 = shape.broadcast %402, %403 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %696 = stablehlo.dynamic_broadcast_in_dim %395, %695, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %697 = stablehlo.dynamic_broadcast_in_dim %401, %695, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %698 = stablehlo.multiply %696, %697 : tensor + shape.assuming_yield %698 : tensor + } + %406 = shape.shape_of %405 : tensor -> tensor<2xindex> + %407 = stablehlo.dynamic_broadcast_in_dim %405, %406, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %408 = stablehlo.dynamic_broadcast_in_dim %201, %406, dims = [] : (tensor, tensor<2xindex>) -> tensor + %409 = stablehlo.multiply %407, %408 : tensor + %dim_148 = tensor.dim %364, %c0 : tensor + %410 = arith.index_cast %dim_148 : index to i64 + %dim_149 = tensor.dim %405, %c0 : tensor + %411 = arith.index_cast %dim_149 : index to i64 + %412 = arith.maxsi %410, %411 : i64 + %413 = arith.index_cast %412 : i64 to index + %from_elements_150 = tensor.from_elements %413, %c32 : tensor<2xindex> + %414 = stablehlo.dynamic_broadcast_in_dim %364, %from_elements_150, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_151 = tensor.dim %414, %c0 : tensor + %415 = arith.index_cast %dim_151 : index to i64 + %from_elements_152 = tensor.from_elements %415, %c32_i64 : tensor<2xi64> + %416 = stablehlo.real_dynamic_slice %409, %c_43, %from_elements_152, %c_44 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_153 = tensor.from_elements %415, %c32_i64, %c1_i64 : tensor<3xi64> + %417 = stablehlo.dynamic_reshape %414, %from_elements_153 : (tensor, tensor<3xi64>) -> tensor + %418 = stablehlo.dynamic_iota %from_elements_153, dim = 1 : (tensor<3xi64>) -> tensor + %419 = stablehlo.concatenate %417, %418, dim = 2 : (tensor, tensor) -> tensor + %420 = "stablehlo.scatter"(%352, %419, %416) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %695 = stablehlo.add %arg3, %arg4 : tensor + stablehlo.return %695 : tensor + }) : (tensor<35x32xf32>, tensor, tensor) -> tensor<35x32xf32> + %421 = stablehlo.slice %125 [4:5, 0:2, 0:35] : (tensor<8x2x35xi64>) -> tensor<1x2x35xi64> + %422 = stablehlo.reshape %421 : (tensor<1x2x35xi64>) -> tensor<2x35xi64> + %423 = stablehlo.custom_call @byteir.non_zero(%422) {byteir_attrs = {}} : (tensor<2x35xi64>) -> tensor + %dim_154 = tensor.dim %423, %c0 : tensor + %424 = arith.index_cast %dim_154 : index to i64 + %from_elements_155 = tensor.from_elements %424, %c1_i64 : tensor<2xi64> + %425 = stablehlo.real_dynamic_slice %423, 
%c_43, %from_elements_155, %c_44 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_156 = tensor.dim %425, %c0 : tensor + %426 = arith.index_cast %dim_156 : index to i64 + %from_elements_157 = tensor.from_elements %426 : tensor<1xi64> + %427 = stablehlo.dynamic_reshape %425, %from_elements_157 : (tensor, tensor<1xi64>) -> tensor + %from_elements_158 = tensor.from_elements %424, %c2_i64 : tensor<2xi64> + %428 = stablehlo.real_dynamic_slice %423, %c_46, %from_elements_158, %c_44 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_159 = tensor.dim %428, %c0 : tensor + %429 = arith.index_cast %dim_159 : index to i64 + %from_elements_160 = tensor.from_elements %429 : tensor<1xi64> + %430 = stablehlo.dynamic_reshape %428, %from_elements_160 : (tensor, tensor<1xi64>) -> tensor + %dim_161 = tensor.dim %430, %c0 : tensor + %431 = arith.index_cast %dim_161 : index to i64 + %from_elements_162 = tensor.from_elements %431, %c1_i64 : tensor<2xi64> + %432 = stablehlo.dynamic_reshape %430, %from_elements_162 : (tensor, tensor<2xi64>) -> tensor + %dim_163 = tensor.dim %432, %c0 : tensor + %433 = arith.index_cast %dim_163 : index to i64 + %from_elements_164 = tensor.from_elements %c1_i64, %433, %c32_i64 : tensor<3xi64> + %434 = stablehlo.dynamic_broadcast_in_dim %158, %from_elements_164, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x32xi64> + %dim_165 = tensor.dim %434, %c1 : tensor<1x?x32xi64> + %435 = arith.index_cast %dim_165 : index to i64 + %from_elements_166 = tensor.from_elements %c1_i64, %435, %c32_i64, %c1_i64 : tensor<4xi64> + %436 = stablehlo.dynamic_reshape %434, %from_elements_166 : (tensor<1x?x32xi64>, tensor<4xi64>) -> tensor<1x?x32x1xi64> + %437 = stablehlo.dynamic_broadcast_in_dim %432, %from_elements_164, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x32xi64> + %dim_167 = tensor.dim %437, %c1 : tensor<1x?x32xi64> + %438 = arith.index_cast %dim_167 : index to i64 + %from_elements_168 = tensor.from_elements %c1_i64, %438, %c32_i64, %c1_i64 : tensor<4xi64> + %439 = stablehlo.dynamic_reshape %437, %from_elements_168 : (tensor<1x?x32xi64>, tensor<4xi64>) -> tensor<1x?x32x1xi64> + %440 = stablehlo.dynamic_broadcast_in_dim %145, %from_elements_164, dims = [2] : (tensor<32xi64>, tensor<3xi64>) -> tensor<1x?x32xi64> + %dim_169 = tensor.dim %440, %c1 : tensor<1x?x32xi64> + %441 = arith.index_cast %dim_169 : index to i64 + %from_elements_170 = tensor.from_elements %c1_i64, %441, %c32_i64, %c1_i64 : tensor<4xi64> + %442 = stablehlo.dynamic_reshape %440, %from_elements_170 : (tensor<1x?x32xi64>, tensor<4xi64>) -> tensor<1x?x32x1xi64> + %443 = stablehlo.concatenate %436, %439, %442, dim = 3 : (tensor<1x?x32x1xi64>, tensor<1x?x32x1xi64>, tensor<1x?x32x1xi64>) -> tensor<1x?x32x3xi64> + %444 = "stablehlo.gather"(%136, %443) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x35x32xf32>, tensor<1x?x32x3xi64>) -> tensor<1x?x32xf32> + %445 = shape.shape_of %444 : tensor<1x?x32xf32> -> tensor<3xindex> + %446 = shape.num_elements %445 : tensor<3xindex> -> index + %447 = stablehlo.compute_reshape_shape %446, %c_45 : (index, tensor<2xi64>) -> tensor<2xi64> + %448 = stablehlo.dynamic_reshape %444, %447 : (tensor<1x?x32xf32>, tensor<2xi64>) -> tensor + %449 = stablehlo.transpose %cst_22, dims = [1, 0] : (tensor<14336x32xf32>) -> tensor<32x14336xf32> + %450 = stablehlo.dot %448, %449 : (tensor, tensor<32x14336xf32>) -> tensor + %451 = stablehlo.logistic %450 : tensor + %452 = shape.shape_of %451 
: tensor -> tensor<2xindex> + %453 = shape.shape_of %450 : tensor -> tensor<2xindex> + %454 = shape.cstr_broadcastable %452, %453 : tensor<2xindex>, tensor<2xindex> + %455 = shape.assuming %454 -> (tensor) { + %695 = shape.broadcast %452, %453 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %696 = stablehlo.dynamic_broadcast_in_dim %451, %695, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %697 = stablehlo.dynamic_broadcast_in_dim %450, %695, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %698 = stablehlo.multiply %696, %697 : tensor + shape.assuming_yield %698 : tensor + } + %456 = stablehlo.transpose %cst_21, dims = [1, 0] : (tensor<14336x32xf32>) -> tensor<32x14336xf32> + %457 = stablehlo.dot %448, %456 : (tensor, tensor<32x14336xf32>) -> tensor + %458 = shape.shape_of %455 : tensor -> tensor<2xindex> + %459 = shape.shape_of %457 : tensor -> tensor<2xindex> + %460 = shape.cstr_broadcastable %458, %459 : tensor<2xindex>, tensor<2xindex> + %461 = shape.assuming %460 -> (tensor) { + %695 = shape.broadcast %458, %459 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %696 = stablehlo.dynamic_broadcast_in_dim %455, %695, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %697 = stablehlo.dynamic_broadcast_in_dim %457, %695, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %698 = stablehlo.multiply %696, %697 : tensor + shape.assuming_yield %698 : tensor + } + %462 = stablehlo.transpose %cst_20, dims = [1, 0] : (tensor<32x14336xf32>) -> tensor<14336x32xf32> + %463 = stablehlo.dot %461, %462 : (tensor, tensor<14336x32xf32>) -> tensor + %dim_171 = tensor.dim %430, %c0 : tensor + %464 = arith.index_cast %dim_171 : index to i64 + %from_elements_172 = tensor.from_elements %464, %c1_i64 : tensor<2xi64> + %465 = stablehlo.dynamic_reshape %430, %from_elements_172 : (tensor, tensor<2xi64>) -> tensor + %dim_173 = tensor.dim %427, %c0 : tensor + %466 = arith.index_cast %dim_173 : index to i64 + %from_elements_174 = tensor.from_elements %466, %c1_i64 : tensor<2xi64> + %467 = stablehlo.dynamic_reshape %427, %from_elements_174 : (tensor, tensor<2xi64>) -> tensor + %468 = stablehlo.concatenate %465, %467, dim = 1 : (tensor, tensor) -> tensor + %469 = "stablehlo.gather"(%190, %468) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<35x2x1xf32>, tensor) -> tensor + %470 = shape.shape_of %463 : tensor -> tensor<2xindex> + %471 = shape.shape_of %469 : tensor -> tensor<2xindex> + %472 = shape.cstr_broadcastable %470, %471 : tensor<2xindex>, tensor<2xindex> + %473 = shape.assuming %472 -> (tensor) { + %695 = shape.broadcast %470, %471 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %696 = stablehlo.dynamic_broadcast_in_dim %463, %695, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %697 = stablehlo.dynamic_broadcast_in_dim %469, %695, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %698 = stablehlo.multiply %696, %697 : tensor + shape.assuming_yield %698 : tensor + } + %474 = shape.shape_of %473 : tensor -> tensor<2xindex> + %475 = stablehlo.dynamic_broadcast_in_dim %473, %474, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %476 = stablehlo.dynamic_broadcast_in_dim %201, %474, dims = [] : (tensor, tensor<2xindex>) -> tensor + %477 = stablehlo.multiply %475, %476 : tensor + %dim_175 = tensor.dim %432, %c0 : tensor + %478 = arith.index_cast %dim_175 : index to i64 + %dim_176 = tensor.dim %473, %c0 : tensor + %479 = arith.index_cast %dim_176 : index to i64 + %480 = arith.maxsi %478, %479 : i64 + %481 = 
arith.index_cast %480 : i64 to index + %from_elements_177 = tensor.from_elements %481, %c32 : tensor<2xindex> + %482 = stablehlo.dynamic_broadcast_in_dim %432, %from_elements_177, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_178 = tensor.dim %482, %c0 : tensor + %483 = arith.index_cast %dim_178 : index to i64 + %from_elements_179 = tensor.from_elements %483, %c32_i64 : tensor<2xi64> + %484 = stablehlo.real_dynamic_slice %477, %c_43, %from_elements_179, %c_44 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_180 = tensor.from_elements %483, %c32_i64, %c1_i64 : tensor<3xi64> + %485 = stablehlo.dynamic_reshape %482, %from_elements_180 : (tensor, tensor<3xi64>) -> tensor + %486 = stablehlo.dynamic_iota %from_elements_180, dim = 1 : (tensor<3xi64>) -> tensor + %487 = stablehlo.concatenate %485, %486, dim = 2 : (tensor, tensor) -> tensor + %488 = "stablehlo.scatter"(%420, %487, %484) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %695 = stablehlo.add %arg3, %arg4 : tensor + stablehlo.return %695 : tensor + }) : (tensor<35x32xf32>, tensor, tensor) -> tensor<35x32xf32> + %489 = stablehlo.slice %125 [5:6, 0:2, 0:35] : (tensor<8x2x35xi64>) -> tensor<1x2x35xi64> + %490 = stablehlo.reshape %489 : (tensor<1x2x35xi64>) -> tensor<2x35xi64> + %491 = stablehlo.custom_call @byteir.non_zero(%490) {byteir_attrs = {}} : (tensor<2x35xi64>) -> tensor + %dim_181 = tensor.dim %491, %c0 : tensor + %492 = arith.index_cast %dim_181 : index to i64 + %from_elements_182 = tensor.from_elements %492, %c1_i64 : tensor<2xi64> + %493 = stablehlo.real_dynamic_slice %491, %c_43, %from_elements_182, %c_44 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_183 = tensor.dim %493, %c0 : tensor + %494 = arith.index_cast %dim_183 : index to i64 + %from_elements_184 = tensor.from_elements %494 : tensor<1xi64> + %495 = stablehlo.dynamic_reshape %493, %from_elements_184 : (tensor, tensor<1xi64>) -> tensor + %from_elements_185 = tensor.from_elements %492, %c2_i64 : tensor<2xi64> + %496 = stablehlo.real_dynamic_slice %491, %c_46, %from_elements_185, %c_44 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_186 = tensor.dim %496, %c0 : tensor + %497 = arith.index_cast %dim_186 : index to i64 + %from_elements_187 = tensor.from_elements %497 : tensor<1xi64> + %498 = stablehlo.dynamic_reshape %496, %from_elements_187 : (tensor, tensor<1xi64>) -> tensor + %dim_188 = tensor.dim %498, %c0 : tensor + %499 = arith.index_cast %dim_188 : index to i64 + %from_elements_189 = tensor.from_elements %499, %c1_i64 : tensor<2xi64> + %500 = stablehlo.dynamic_reshape %498, %from_elements_189 : (tensor, tensor<2xi64>) -> tensor + %dim_190 = tensor.dim %500, %c0 : tensor + %501 = arith.index_cast %dim_190 : index to i64 + %from_elements_191 = tensor.from_elements %c1_i64, %501, %c32_i64 : tensor<3xi64> + %502 = stablehlo.dynamic_broadcast_in_dim %158, %from_elements_191, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x32xi64> + %dim_192 = tensor.dim %502, %c1 : tensor<1x?x32xi64> + %503 = arith.index_cast %dim_192 : index to i64 + %from_elements_193 = tensor.from_elements %c1_i64, %503, %c32_i64, %c1_i64 : tensor<4xi64> + %504 = stablehlo.dynamic_reshape %502, %from_elements_193 : (tensor<1x?x32xi64>, tensor<4xi64>) -> tensor<1x?x32x1xi64> + %505 = stablehlo.dynamic_broadcast_in_dim %500, %from_elements_191, dims = [1, 2] : (tensor, tensor<3xi64>) -> 
tensor<1x?x32xi64> + %dim_194 = tensor.dim %505, %c1 : tensor<1x?x32xi64> + %506 = arith.index_cast %dim_194 : index to i64 + %from_elements_195 = tensor.from_elements %c1_i64, %506, %c32_i64, %c1_i64 : tensor<4xi64> + %507 = stablehlo.dynamic_reshape %505, %from_elements_195 : (tensor<1x?x32xi64>, tensor<4xi64>) -> tensor<1x?x32x1xi64> + %508 = stablehlo.dynamic_broadcast_in_dim %145, %from_elements_191, dims = [2] : (tensor<32xi64>, tensor<3xi64>) -> tensor<1x?x32xi64> + %dim_196 = tensor.dim %508, %c1 : tensor<1x?x32xi64> + %509 = arith.index_cast %dim_196 : index to i64 + %from_elements_197 = tensor.from_elements %c1_i64, %509, %c32_i64, %c1_i64 : tensor<4xi64> + %510 = stablehlo.dynamic_reshape %508, %from_elements_197 : (tensor<1x?x32xi64>, tensor<4xi64>) -> tensor<1x?x32x1xi64> + %511 = stablehlo.concatenate %504, %507, %510, dim = 3 : (tensor<1x?x32x1xi64>, tensor<1x?x32x1xi64>, tensor<1x?x32x1xi64>) -> tensor<1x?x32x3xi64> + %512 = "stablehlo.gather"(%136, %511) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x35x32xf32>, tensor<1x?x32x3xi64>) -> tensor<1x?x32xf32> + %513 = shape.shape_of %512 : tensor<1x?x32xf32> -> tensor<3xindex> + %514 = shape.num_elements %513 : tensor<3xindex> -> index + %515 = stablehlo.compute_reshape_shape %514, %c_45 : (index, tensor<2xi64>) -> tensor<2xi64> + %516 = stablehlo.dynamic_reshape %512, %515 : (tensor<1x?x32xf32>, tensor<2xi64>) -> tensor + %517 = stablehlo.transpose %cst_19, dims = [1, 0] : (tensor<14336x32xf32>) -> tensor<32x14336xf32> + %518 = stablehlo.dot %516, %517 : (tensor, tensor<32x14336xf32>) -> tensor + %519 = stablehlo.logistic %518 : tensor + %520 = shape.shape_of %519 : tensor -> tensor<2xindex> + %521 = shape.shape_of %518 : tensor -> tensor<2xindex> + %522 = shape.cstr_broadcastable %520, %521 : tensor<2xindex>, tensor<2xindex> + %523 = shape.assuming %522 -> (tensor) { + %695 = shape.broadcast %520, %521 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %696 = stablehlo.dynamic_broadcast_in_dim %519, %695, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %697 = stablehlo.dynamic_broadcast_in_dim %518, %695, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %698 = stablehlo.multiply %696, %697 : tensor + shape.assuming_yield %698 : tensor + } + %524 = stablehlo.transpose %cst_18, dims = [1, 0] : (tensor<14336x32xf32>) -> tensor<32x14336xf32> + %525 = stablehlo.dot %516, %524 : (tensor, tensor<32x14336xf32>) -> tensor + %526 = shape.shape_of %523 : tensor -> tensor<2xindex> + %527 = shape.shape_of %525 : tensor -> tensor<2xindex> + %528 = shape.cstr_broadcastable %526, %527 : tensor<2xindex>, tensor<2xindex> + %529 = shape.assuming %528 -> (tensor) { + %695 = shape.broadcast %526, %527 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %696 = stablehlo.dynamic_broadcast_in_dim %523, %695, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %697 = stablehlo.dynamic_broadcast_in_dim %525, %695, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %698 = stablehlo.multiply %696, %697 : tensor + shape.assuming_yield %698 : tensor + } + %530 = stablehlo.transpose %cst_17, dims = [1, 0] : (tensor<32x14336xf32>) -> tensor<14336x32xf32> + %531 = stablehlo.dot %529, %530 : (tensor, tensor<14336x32xf32>) -> tensor + %dim_198 = tensor.dim %498, %c0 : tensor + %532 = arith.index_cast %dim_198 : index to i64 + %from_elements_199 = tensor.from_elements %532, %c1_i64 : tensor<2xi64> + %533 = stablehlo.dynamic_reshape %498, %from_elements_199 : (tensor, 
tensor<2xi64>) -> tensor + %dim_200 = tensor.dim %495, %c0 : tensor + %534 = arith.index_cast %dim_200 : index to i64 + %from_elements_201 = tensor.from_elements %534, %c1_i64 : tensor<2xi64> + %535 = stablehlo.dynamic_reshape %495, %from_elements_201 : (tensor, tensor<2xi64>) -> tensor + %536 = stablehlo.concatenate %533, %535, dim = 1 : (tensor, tensor) -> tensor + %537 = "stablehlo.gather"(%190, %536) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<35x2x1xf32>, tensor) -> tensor + %538 = shape.shape_of %531 : tensor -> tensor<2xindex> + %539 = shape.shape_of %537 : tensor -> tensor<2xindex> + %540 = shape.cstr_broadcastable %538, %539 : tensor<2xindex>, tensor<2xindex> + %541 = shape.assuming %540 -> (tensor) { + %695 = shape.broadcast %538, %539 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %696 = stablehlo.dynamic_broadcast_in_dim %531, %695, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %697 = stablehlo.dynamic_broadcast_in_dim %537, %695, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %698 = stablehlo.multiply %696, %697 : tensor + shape.assuming_yield %698 : tensor + } + %542 = shape.shape_of %541 : tensor -> tensor<2xindex> + %543 = stablehlo.dynamic_broadcast_in_dim %541, %542, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %544 = stablehlo.dynamic_broadcast_in_dim %201, %542, dims = [] : (tensor, tensor<2xindex>) -> tensor + %545 = stablehlo.multiply %543, %544 : tensor + %dim_202 = tensor.dim %500, %c0 : tensor + %546 = arith.index_cast %dim_202 : index to i64 + %dim_203 = tensor.dim %541, %c0 : tensor + %547 = arith.index_cast %dim_203 : index to i64 + %548 = arith.maxsi %546, %547 : i64 + %549 = arith.index_cast %548 : i64 to index + %from_elements_204 = tensor.from_elements %549, %c32 : tensor<2xindex> + %550 = stablehlo.dynamic_broadcast_in_dim %500, %from_elements_204, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_205 = tensor.dim %550, %c0 : tensor + %551 = arith.index_cast %dim_205 : index to i64 + %from_elements_206 = tensor.from_elements %551, %c32_i64 : tensor<2xi64> + %552 = stablehlo.real_dynamic_slice %545, %c_43, %from_elements_206, %c_44 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_207 = tensor.from_elements %551, %c32_i64, %c1_i64 : tensor<3xi64> + %553 = stablehlo.dynamic_reshape %550, %from_elements_207 : (tensor, tensor<3xi64>) -> tensor + %554 = stablehlo.dynamic_iota %from_elements_207, dim = 1 : (tensor<3xi64>) -> tensor + %555 = stablehlo.concatenate %553, %554, dim = 2 : (tensor, tensor) -> tensor + %556 = "stablehlo.scatter"(%488, %555, %552) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %695 = stablehlo.add %arg3, %arg4 : tensor + stablehlo.return %695 : tensor + }) : (tensor<35x32xf32>, tensor, tensor) -> tensor<35x32xf32> + %557 = stablehlo.slice %125 [6:7, 0:2, 0:35] : (tensor<8x2x35xi64>) -> tensor<1x2x35xi64> + %558 = stablehlo.reshape %557 : (tensor<1x2x35xi64>) -> tensor<2x35xi64> + %559 = stablehlo.custom_call @byteir.non_zero(%558) {byteir_attrs = {}} : (tensor<2x35xi64>) -> tensor + %dim_208 = tensor.dim %559, %c0 : tensor + %560 = arith.index_cast %dim_208 : index to i64 + %from_elements_209 = tensor.from_elements %560, %c1_i64 : tensor<2xi64> + %561 = stablehlo.real_dynamic_slice %559, %c_43, %from_elements_209, %c_44 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_210 = tensor.dim 
%561, %c0 : tensor + %562 = arith.index_cast %dim_210 : index to i64 + %from_elements_211 = tensor.from_elements %562 : tensor<1xi64> + %563 = stablehlo.dynamic_reshape %561, %from_elements_211 : (tensor, tensor<1xi64>) -> tensor + %from_elements_212 = tensor.from_elements %560, %c2_i64 : tensor<2xi64> + %564 = stablehlo.real_dynamic_slice %559, %c_46, %from_elements_212, %c_44 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_213 = tensor.dim %564, %c0 : tensor + %565 = arith.index_cast %dim_213 : index to i64 + %from_elements_214 = tensor.from_elements %565 : tensor<1xi64> + %566 = stablehlo.dynamic_reshape %564, %from_elements_214 : (tensor, tensor<1xi64>) -> tensor + %dim_215 = tensor.dim %566, %c0 : tensor + %567 = arith.index_cast %dim_215 : index to i64 + %from_elements_216 = tensor.from_elements %567, %c1_i64 : tensor<2xi64> + %568 = stablehlo.dynamic_reshape %566, %from_elements_216 : (tensor, tensor<2xi64>) -> tensor + %dim_217 = tensor.dim %568, %c0 : tensor + %569 = arith.index_cast %dim_217 : index to i64 + %from_elements_218 = tensor.from_elements %c1_i64, %569, %c32_i64 : tensor<3xi64> + %570 = stablehlo.dynamic_broadcast_in_dim %158, %from_elements_218, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x32xi64> + %dim_219 = tensor.dim %570, %c1 : tensor<1x?x32xi64> + %571 = arith.index_cast %dim_219 : index to i64 + %from_elements_220 = tensor.from_elements %c1_i64, %571, %c32_i64, %c1_i64 : tensor<4xi64> + %572 = stablehlo.dynamic_reshape %570, %from_elements_220 : (tensor<1x?x32xi64>, tensor<4xi64>) -> tensor<1x?x32x1xi64> + %573 = stablehlo.dynamic_broadcast_in_dim %568, %from_elements_218, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x32xi64> + %dim_221 = tensor.dim %573, %c1 : tensor<1x?x32xi64> + %574 = arith.index_cast %dim_221 : index to i64 + %from_elements_222 = tensor.from_elements %c1_i64, %574, %c32_i64, %c1_i64 : tensor<4xi64> + %575 = stablehlo.dynamic_reshape %573, %from_elements_222 : (tensor<1x?x32xi64>, tensor<4xi64>) -> tensor<1x?x32x1xi64> + %576 = stablehlo.dynamic_broadcast_in_dim %145, %from_elements_218, dims = [2] : (tensor<32xi64>, tensor<3xi64>) -> tensor<1x?x32xi64> + %dim_223 = tensor.dim %576, %c1 : tensor<1x?x32xi64> + %577 = arith.index_cast %dim_223 : index to i64 + %from_elements_224 = tensor.from_elements %c1_i64, %577, %c32_i64, %c1_i64 : tensor<4xi64> + %578 = stablehlo.dynamic_reshape %576, %from_elements_224 : (tensor<1x?x32xi64>, tensor<4xi64>) -> tensor<1x?x32x1xi64> + %579 = stablehlo.concatenate %572, %575, %578, dim = 3 : (tensor<1x?x32x1xi64>, tensor<1x?x32x1xi64>, tensor<1x?x32x1xi64>) -> tensor<1x?x32x3xi64> + %580 = "stablehlo.gather"(%136, %579) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x35x32xf32>, tensor<1x?x32x3xi64>) -> tensor<1x?x32xf32> + %581 = shape.shape_of %580 : tensor<1x?x32xf32> -> tensor<3xindex> + %582 = shape.num_elements %581 : tensor<3xindex> -> index + %583 = stablehlo.compute_reshape_shape %582, %c_45 : (index, tensor<2xi64>) -> tensor<2xi64> + %584 = stablehlo.dynamic_reshape %580, %583 : (tensor<1x?x32xf32>, tensor<2xi64>) -> tensor + %585 = stablehlo.transpose %cst_16, dims = [1, 0] : (tensor<14336x32xf32>) -> tensor<32x14336xf32> + %586 = stablehlo.dot %584, %585 : (tensor, tensor<32x14336xf32>) -> tensor + %587 = stablehlo.logistic %586 : tensor + %588 = shape.shape_of %587 : tensor -> tensor<2xindex> + %589 = shape.shape_of %586 : tensor -> tensor<2xindex> + %590 = shape.cstr_broadcastable 
%588, %589 : tensor<2xindex>, tensor<2xindex> + %591 = shape.assuming %590 -> (tensor) { + %695 = shape.broadcast %588, %589 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %696 = stablehlo.dynamic_broadcast_in_dim %587, %695, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %697 = stablehlo.dynamic_broadcast_in_dim %586, %695, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %698 = stablehlo.multiply %696, %697 : tensor + shape.assuming_yield %698 : tensor + } + %592 = stablehlo.transpose %cst_15, dims = [1, 0] : (tensor<14336x32xf32>) -> tensor<32x14336xf32> + %593 = stablehlo.dot %584, %592 : (tensor, tensor<32x14336xf32>) -> tensor + %594 = shape.shape_of %591 : tensor -> tensor<2xindex> + %595 = shape.shape_of %593 : tensor -> tensor<2xindex> + %596 = shape.cstr_broadcastable %594, %595 : tensor<2xindex>, tensor<2xindex> + %597 = shape.assuming %596 -> (tensor) { + %695 = shape.broadcast %594, %595 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %696 = stablehlo.dynamic_broadcast_in_dim %591, %695, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %697 = stablehlo.dynamic_broadcast_in_dim %593, %695, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %698 = stablehlo.multiply %696, %697 : tensor + shape.assuming_yield %698 : tensor + } + %598 = stablehlo.transpose %cst_14, dims = [1, 0] : (tensor<32x14336xf32>) -> tensor<14336x32xf32> + %599 = stablehlo.dot %597, %598 : (tensor, tensor<14336x32xf32>) -> tensor + %dim_225 = tensor.dim %566, %c0 : tensor + %600 = arith.index_cast %dim_225 : index to i64 + %from_elements_226 = tensor.from_elements %600, %c1_i64 : tensor<2xi64> + %601 = stablehlo.dynamic_reshape %566, %from_elements_226 : (tensor, tensor<2xi64>) -> tensor + %dim_227 = tensor.dim %563, %c0 : tensor + %602 = arith.index_cast %dim_227 : index to i64 + %from_elements_228 = tensor.from_elements %602, %c1_i64 : tensor<2xi64> + %603 = stablehlo.dynamic_reshape %563, %from_elements_228 : (tensor, tensor<2xi64>) -> tensor + %604 = stablehlo.concatenate %601, %603, dim = 1 : (tensor, tensor) -> tensor + %605 = "stablehlo.gather"(%190, %604) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<35x2x1xf32>, tensor) -> tensor + %606 = shape.shape_of %599 : tensor -> tensor<2xindex> + %607 = shape.shape_of %605 : tensor -> tensor<2xindex> + %608 = shape.cstr_broadcastable %606, %607 : tensor<2xindex>, tensor<2xindex> + %609 = shape.assuming %608 -> (tensor) { + %695 = shape.broadcast %606, %607 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %696 = stablehlo.dynamic_broadcast_in_dim %599, %695, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %697 = stablehlo.dynamic_broadcast_in_dim %605, %695, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %698 = stablehlo.multiply %696, %697 : tensor + shape.assuming_yield %698 : tensor + } + %610 = shape.shape_of %609 : tensor -> tensor<2xindex> + %611 = stablehlo.dynamic_broadcast_in_dim %609, %610, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %612 = stablehlo.dynamic_broadcast_in_dim %201, %610, dims = [] : (tensor, tensor<2xindex>) -> tensor + %613 = stablehlo.multiply %611, %612 : tensor + %dim_229 = tensor.dim %568, %c0 : tensor + %614 = arith.index_cast %dim_229 : index to i64 + %dim_230 = tensor.dim %609, %c0 : tensor + %615 = arith.index_cast %dim_230 : index to i64 + %616 = arith.maxsi %614, %615 : i64 + %617 = arith.index_cast %616 : i64 to index + %from_elements_231 = tensor.from_elements %617, %c32 : tensor<2xindex> + %618 = 
stablehlo.dynamic_broadcast_in_dim %568, %from_elements_231, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_232 = tensor.dim %618, %c0 : tensor + %619 = arith.index_cast %dim_232 : index to i64 + %from_elements_233 = tensor.from_elements %619, %c32_i64 : tensor<2xi64> + %620 = stablehlo.real_dynamic_slice %613, %c_43, %from_elements_233, %c_44 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_234 = tensor.from_elements %619, %c32_i64, %c1_i64 : tensor<3xi64> + %621 = stablehlo.dynamic_reshape %618, %from_elements_234 : (tensor, tensor<3xi64>) -> tensor + %622 = stablehlo.dynamic_iota %from_elements_234, dim = 1 : (tensor<3xi64>) -> tensor + %623 = stablehlo.concatenate %621, %622, dim = 2 : (tensor, tensor) -> tensor + %624 = "stablehlo.scatter"(%556, %623, %620) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %695 = stablehlo.add %arg3, %arg4 : tensor + stablehlo.return %695 : tensor + }) : (tensor<35x32xf32>, tensor, tensor) -> tensor<35x32xf32> + %625 = stablehlo.slice %125 [7:8, 0:2, 0:35] : (tensor<8x2x35xi64>) -> tensor<1x2x35xi64> + %626 = stablehlo.reshape %625 : (tensor<1x2x35xi64>) -> tensor<2x35xi64> + %627 = stablehlo.custom_call @byteir.non_zero(%626) {byteir_attrs = {}} : (tensor<2x35xi64>) -> tensor + %dim_235 = tensor.dim %627, %c0 : tensor + %628 = arith.index_cast %dim_235 : index to i64 + %from_elements_236 = tensor.from_elements %628, %c1_i64 : tensor<2xi64> + %629 = stablehlo.real_dynamic_slice %627, %c_43, %from_elements_236, %c_44 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_237 = tensor.dim %629, %c0 : tensor + %630 = arith.index_cast %dim_237 : index to i64 + %from_elements_238 = tensor.from_elements %630 : tensor<1xi64> + %631 = stablehlo.dynamic_reshape %629, %from_elements_238 : (tensor, tensor<1xi64>) -> tensor + %from_elements_239 = tensor.from_elements %628, %c2_i64 : tensor<2xi64> + %632 = stablehlo.real_dynamic_slice %627, %c_46, %from_elements_239, %c_44 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %dim_240 = tensor.dim %632, %c0 : tensor + %633 = arith.index_cast %dim_240 : index to i64 + %from_elements_241 = tensor.from_elements %633 : tensor<1xi64> + %634 = stablehlo.dynamic_reshape %632, %from_elements_241 : (tensor, tensor<1xi64>) -> tensor + %dim_242 = tensor.dim %634, %c0 : tensor + %635 = arith.index_cast %dim_242 : index to i64 + %from_elements_243 = tensor.from_elements %635, %c1_i64 : tensor<2xi64> + %636 = stablehlo.dynamic_reshape %634, %from_elements_243 : (tensor, tensor<2xi64>) -> tensor + %dim_244 = tensor.dim %636, %c0 : tensor + %637 = arith.index_cast %dim_244 : index to i64 + %from_elements_245 = tensor.from_elements %c1_i64, %637, %c32_i64 : tensor<3xi64> + %638 = stablehlo.dynamic_broadcast_in_dim %158, %from_elements_245, dims = [0, 1, 2] : (tensor<1x1x1xi64>, tensor<3xi64>) -> tensor<1x?x32xi64> + %dim_246 = tensor.dim %638, %c1 : tensor<1x?x32xi64> + %639 = arith.index_cast %dim_246 : index to i64 + %from_elements_247 = tensor.from_elements %c1_i64, %639, %c32_i64, %c1_i64 : tensor<4xi64> + %640 = stablehlo.dynamic_reshape %638, %from_elements_247 : (tensor<1x?x32xi64>, tensor<4xi64>) -> tensor<1x?x32x1xi64> + %641 = stablehlo.dynamic_broadcast_in_dim %636, %from_elements_245, dims = [1, 2] : (tensor, tensor<3xi64>) -> tensor<1x?x32xi64> + %dim_248 = tensor.dim %641, %c1 : tensor<1x?x32xi64> + %642 = arith.index_cast %dim_248 : index to i64 
+ %from_elements_249 = tensor.from_elements %c1_i64, %642, %c32_i64, %c1_i64 : tensor<4xi64> + %643 = stablehlo.dynamic_reshape %641, %from_elements_249 : (tensor<1x?x32xi64>, tensor<4xi64>) -> tensor<1x?x32x1xi64> + %644 = stablehlo.dynamic_broadcast_in_dim %145, %from_elements_245, dims = [2] : (tensor<32xi64>, tensor<3xi64>) -> tensor<1x?x32xi64> + %dim_250 = tensor.dim %644, %c1 : tensor<1x?x32xi64> + %645 = arith.index_cast %dim_250 : index to i64 + %from_elements_251 = tensor.from_elements %c1_i64, %645, %c32_i64, %c1_i64 : tensor<4xi64> + %646 = stablehlo.dynamic_reshape %644, %from_elements_251 : (tensor<1x?x32xi64>, tensor<4xi64>) -> tensor<1x?x32x1xi64> + %647 = stablehlo.concatenate %640, %643, %646, dim = 3 : (tensor<1x?x32x1xi64>, tensor<1x?x32x1xi64>, tensor<1x?x32x1xi64>) -> tensor<1x?x32x3xi64> + %648 = "stablehlo.gather"(%136, %647) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<1x35x32xf32>, tensor<1x?x32x3xi64>) -> tensor<1x?x32xf32> + %649 = shape.shape_of %648 : tensor<1x?x32xf32> -> tensor<3xindex> + %650 = shape.num_elements %649 : tensor<3xindex> -> index + %651 = stablehlo.compute_reshape_shape %650, %c_45 : (index, tensor<2xi64>) -> tensor<2xi64> + %652 = stablehlo.dynamic_reshape %648, %651 : (tensor<1x?x32xf32>, tensor<2xi64>) -> tensor + %653 = stablehlo.transpose %cst_13, dims = [1, 0] : (tensor<14336x32xf32>) -> tensor<32x14336xf32> + %654 = stablehlo.dot %652, %653 : (tensor, tensor<32x14336xf32>) -> tensor + %655 = stablehlo.logistic %654 : tensor + %656 = shape.shape_of %655 : tensor -> tensor<2xindex> + %657 = shape.shape_of %654 : tensor -> tensor<2xindex> + %658 = shape.cstr_broadcastable %656, %657 : tensor<2xindex>, tensor<2xindex> + %659 = shape.assuming %658 -> (tensor) { + %695 = shape.broadcast %656, %657 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %696 = stablehlo.dynamic_broadcast_in_dim %655, %695, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %697 = stablehlo.dynamic_broadcast_in_dim %654, %695, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %698 = stablehlo.multiply %696, %697 : tensor + shape.assuming_yield %698 : tensor + } + %660 = stablehlo.transpose %cst_12, dims = [1, 0] : (tensor<14336x32xf32>) -> tensor<32x14336xf32> + %661 = stablehlo.dot %652, %660 : (tensor, tensor<32x14336xf32>) -> tensor + %662 = shape.shape_of %659 : tensor -> tensor<2xindex> + %663 = shape.shape_of %661 : tensor -> tensor<2xindex> + %664 = shape.cstr_broadcastable %662, %663 : tensor<2xindex>, tensor<2xindex> + %665 = shape.assuming %664 -> (tensor) { + %695 = shape.broadcast %662, %663 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %696 = stablehlo.dynamic_broadcast_in_dim %659, %695, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %697 = stablehlo.dynamic_broadcast_in_dim %661, %695, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %698 = stablehlo.multiply %696, %697 : tensor + shape.assuming_yield %698 : tensor + } + %666 = stablehlo.transpose %cst_11, dims = [1, 0] : (tensor<32x14336xf32>) -> tensor<14336x32xf32> + %667 = stablehlo.dot %665, %666 : (tensor, tensor<14336x32xf32>) -> tensor + %dim_252 = tensor.dim %634, %c0 : tensor + %668 = arith.index_cast %dim_252 : index to i64 + %from_elements_253 = tensor.from_elements %668, %c1_i64 : tensor<2xi64> + %669 = stablehlo.dynamic_reshape %634, %from_elements_253 : (tensor, tensor<2xi64>) -> tensor + %dim_254 = tensor.dim %631, %c0 : tensor + %670 = arith.index_cast %dim_254 : index to i64 + 
%from_elements_255 = tensor.from_elements %670, %c1_i64 : tensor<2xi64> + %671 = stablehlo.dynamic_reshape %631, %from_elements_255 : (tensor, tensor<2xi64>) -> tensor + %672 = stablehlo.concatenate %669, %671, dim = 1 : (tensor, tensor) -> tensor + %673 = "stablehlo.gather"(%190, %672) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<35x2x1xf32>, tensor) -> tensor + %674 = shape.shape_of %667 : tensor -> tensor<2xindex> + %675 = shape.shape_of %673 : tensor -> tensor<2xindex> + %676 = shape.cstr_broadcastable %674, %675 : tensor<2xindex>, tensor<2xindex> + %677 = shape.assuming %676 -> (tensor) { + %695 = shape.broadcast %674, %675 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> + %696 = stablehlo.dynamic_broadcast_in_dim %667, %695, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %697 = stablehlo.dynamic_broadcast_in_dim %673, %695, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %698 = stablehlo.multiply %696, %697 : tensor + shape.assuming_yield %698 : tensor + } + %678 = shape.shape_of %677 : tensor -> tensor<2xindex> + %679 = stablehlo.dynamic_broadcast_in_dim %677, %678, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %680 = stablehlo.dynamic_broadcast_in_dim %201, %678, dims = [] : (tensor, tensor<2xindex>) -> tensor + %681 = stablehlo.multiply %679, %680 : tensor + %dim_256 = tensor.dim %636, %c0 : tensor + %682 = arith.index_cast %dim_256 : index to i64 + %dim_257 = tensor.dim %677, %c0 : tensor + %683 = arith.index_cast %dim_257 : index to i64 + %684 = arith.maxsi %682, %683 : i64 + %685 = arith.index_cast %684 : i64 to index + %from_elements_258 = tensor.from_elements %685, %c32 : tensor<2xindex> + %686 = stablehlo.dynamic_broadcast_in_dim %636, %from_elements_258, dims = [0, 1] : (tensor, tensor<2xindex>) -> tensor + %dim_259 = tensor.dim %686, %c0 : tensor + %687 = arith.index_cast %dim_259 : index to i64 + %from_elements_260 = tensor.from_elements %687, %c32_i64 : tensor<2xi64> + %688 = stablehlo.real_dynamic_slice %681, %c_43, %from_elements_260, %c_44 : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + %from_elements_261 = tensor.from_elements %687, %c32_i64, %c1_i64 : tensor<3xi64> + %689 = stablehlo.dynamic_reshape %686, %from_elements_261 : (tensor, tensor<3xi64>) -> tensor + %690 = stablehlo.dynamic_iota %from_elements_261, dim = 1 : (tensor<3xi64>) -> tensor + %691 = stablehlo.concatenate %689, %690, dim = 2 : (tensor, tensor) -> tensor + %692 = "stablehlo.scatter"(%624, %691, %688) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %695 = stablehlo.add %arg3, %arg4 : tensor + stablehlo.return %695 : tensor + }) : (tensor<35x32xf32>, tensor, tensor) -> tensor<35x32xf32> + %693 = stablehlo.reshape %692 : (tensor<35x32xf32>) -> tensor<5x7x32xf32> + %694 = stablehlo.add %86, %693 : tensor<5x7x32xf32> + return %694 : tensor<5x7x32xf32> + } +} + +{-# + dialect_resources: { + builtin: { + torch_tensor_32_torch.float32_1: "0x040000000000803F0000803F0000803F0000803F0000803F0000803F0000803F0000803F0000803F0000803F0000803F0000803F0000803F0000803F0000803F0000803F0000803F0000803F0000803F0000803F0000803F0000803F0000803F0000803F0000803F0000803F0000803F0000803F0000803F0000803F0000803F0000803F", + torch_tensor_32_torch.float32: 
"0x040000000000803F0000803F0000803F0000803F0000803F0000803F0000803F0000803F0000803F0000803F0000803F0000803F0000803F0000803F0000803F0000803F0000803F0000803F0000803F0000803F0000803F0000803F0000803F0000803F0000803F0000803F0000803F0000803F0000803F0000803F0000803F0000803F" + } + } +#-} + diff --git a/frontends/torch-frontend/third_party/patches/fx_importer.patch b/frontends/torch-frontend/third_party/patches/fx_importer.patch new file mode 100644 index 000000000..5b72dc6b0 --- /dev/null +++ b/frontends/torch-frontend/third_party/patches/fx_importer.patch @@ -0,0 +1,158 @@ +diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py +index 381f8f9a..a5e2d32b 100644 +--- a/python/torch_mlir/extras/fx_importer.py ++++ b/python/torch_mlir/extras/fx_importer.py +@@ -52,6 +52,10 @@ from torch._subclasses import ( + FakeTensor as TorchFakeTensor, + ) + ++from torch.distributed._functional_collectives import ( ++ AsyncCollectiveTensor as TorchAsyncCollectiveTensor ++) ++ + from torch.fx import ( + Graph, + GraphModule, +@@ -924,6 +928,19 @@ class ContextCache: + tensor_meta = node.meta.get("tensor_meta") + val = node.meta.get("val") + sparsity = node.meta.get("sparsity", None) ++ # Some nodes returns a list, like torch.ops.aten.unbind.int ++ if isinstance(tensor_meta, List) or isinstance(val, List): ++ if tensor_meta is not None and all(x is not None for x in tensor_meta): ++ # Assume that all results in the list are tensors. ++ # TODO: Solve this assumption ++ return IrType.parse("!torch.list", context=self._c) ++ elif val is not None and all(x is not None for x in val): ++ return IrType.parse("!torch.list", context=self._c) ++ else: ++ raise NotImplementedError( ++ f"FIXME: Unsupported placeholder node (this often indicates that a necessary) " ++ f"fx preprocessing pass was not run): {node.meta}" ++ ) + except KeyError as e: + raise RuntimeError( + f"FIXME: Illegal access to torch.fx.Node.meta: {e} ({node.meta.keys()} : {node.meta})" +@@ -1035,6 +1052,7 @@ class GraphNodeImporter: + "_on_node_produced", + "_v", + "_multi_result_nodes", ++ "_list_return_nodes", + "fx_importer", + ] + +@@ -1058,6 +1076,9 @@ class GraphNodeImporter: + # They will have their getitem calls short-circuited. + self._multi_result_nodes: Set[torch_fx.Node] = set() + ++ # Stores the node that returns a list, like aten.unbind.int ++ self._list_return_nodes: Set[torch_fx.Node] = set() ++ + def bind_node_value( + self, + node: Node, +@@ -1213,6 +1234,23 @@ class GraphNodeImporter: + f"notify developers if this case happens " + f"(at {loc})." + ) ++ elif getitem_ref in self._list_return_nodes: ++ fx_list_return_value = self._v[(getitem_ref, 0)] ++ operands = [ ++ fx_list_return_value, ++ self._import_default_value(loc, getitem_index, torch.IntType) ++ ] ++ ++ # We trust the tensor type in FX graph, even if it's a getitem ++ # from a value of MLIR ListType. ++ operation = Operation.create( ++ "torch.aten.__getitem__.t", ++ results=(self._cc.node_val_to_type(node),), ++ operands = operands, ++ loc=loc ++ ) ++ for i, value in enumerate(operation.results): ++ self._v[(node, i)] = value + else: + raise NotImplementedError( + f"General getitem access to non-multi-result ops" +@@ -1676,6 +1714,10 @@ class GraphNodeImporter: + # Unary return directly maps a single meta["val"] and cannot be subscripted. + # if "tensor_meta" is None, this will throw unsupported placeholder node error + result_types = [self._cc.node_val_to_type(node)] ++ ++ # separately handle ops returning list. 
++ if str(result_types[0]).startswith("!torch.list"): ++ self._list_return_nodes.add(node) + elif return_count == 0: + # Some torch ops do have 0 returns, and these are supported with ZeroResults + # op trait. Python bindings for IR creation allow us to pass empty result_types +@@ -1717,6 +1759,8 @@ def _make_vtensor_literal_op( + ) -> Operation: + mapping = py_attr_tracker.track(tensor) + if mapping.is_empty: ++ # unwrap from TorchAsyncCollectiveTensor ++ tensor = tensor.elem if isinstance(tensor, TorchAsyncCollectiveTensor) else tensor + # check support for bfloat16 + assert not ( + tensor.dtype == torch.bfloat16 and ml_dtypes is None +@@ -1732,29 +1776,42 @@ def _make_vtensor_literal_op( + # detach() which throws an error as we are operating in a FakeTensorMode, hence the simplest way to get this raw + # buffer is via the indirection: Tensor -> list -> numpy array. This allows us to create a vtensor literal as + # desired, but also limits which data types we can support in this function (see TORCH_DTYPE_TO_NPY_TYPE above) +- np_tensor = np.array(tensor.tolist()).astype(npy_dtype) +- # One element constants are more optimizable as splat DenseElementsAttr. DenseResourceElementsAttr does not +- # support splats, so don't use it for that case. In addition, at the time of writing, it has bugs with handling +- # 0d tensors. +- if np_tensor.size == 1: ++ ++ # NOTE: if we torch.export a torch.nn.Module under fake mode, the parameters in the fx.GraphModule will be FakeTensor. ++ # So we specifically handle FakeTensor here by creating a splat DenseElementsAttr with value 0. ++ if isinstance(tensor, TorchFakeTensor): ++ array = np.array([0]).astype(npy_dtype) + try: +- dtype = tensor.dtype +- element_type = TORCH_DTYPE_TO_MLIR_TYPE[dtype]() ++ element_type = TORCH_DTYPE_TO_MLIR_TYPE[tensor.dtype]() + except KeyError: +- raise TypeError(f"Could not map Torch dtype {dtype} to an MLIR type") ++ raise TypeError(f"Could not map Torch dtype {tensor.dtype} to an MLIR type") + elements_attr = DenseElementsAttr.get( +- type=element_type, array=np_tensor, shape=np_tensor.shape ++ array=array, type=element_type, shape=list(tensor.shape) + ) + else: +- bytes_view = np_tensor.view(npy_dtype) +- tensor_type = create_mlir_tensor_type(tensor) +- shape_desc = "_".join([str(d) for d in tensor.shape]) +- blob_name = f"torch_tensor_{shape_desc}_{str(tensor.dtype)}" +- elements_attr = DenseResourceElementsAttr.get_from_buffer( +- bytes_view, +- blob_name, +- tensor_type, +- ) ++ np_tensor = np.array(tensor.tolist()).astype(npy_dtype) ++ # One element constants are more optimizable as splat DenseElementsAttr. DenseResourceElementsAttr does not ++ # support splats, so don't use it for that case. In addition, at the time of writing, it has bugs with handling ++ # 0d tensors. 
++ if np_tensor.size == 1: ++ try: ++ dtype = tensor.dtype ++ element_type = TORCH_DTYPE_TO_MLIR_TYPE[dtype]() ++ except KeyError: ++ raise TypeError(f"Could not map Torch dtype {dtype} to an MLIR type") ++ elements_attr = DenseElementsAttr.get( ++ type=element_type, array=np_tensor, shape=np_tensor.shape ++ ) ++ else: ++ bytes_view = np_tensor.view(npy_dtype) ++ tensor_type = create_mlir_tensor_type(tensor) ++ shape_desc = "_".join([str(d) for d in tensor.shape]) ++ blob_name = f"torch_tensor_{shape_desc}_{str(tensor.dtype)}" ++ elements_attr = DenseResourceElementsAttr.get_from_buffer( ++ bytes_view, ++ blob_name, ++ tensor_type, ++ ) + mapping.value = elements_attr + else: + elements_attr = mapping.value diff --git a/frontends/torch-frontend/third_party/patches/fx_list_return.patch b/frontends/torch-frontend/third_party/patches/fx_list_return.patch deleted file mode 100644 index 1e1129c9b..000000000 --- a/frontends/torch-frontend/third_party/patches/fx_list_return.patch +++ /dev/null @@ -1,77 +0,0 @@ -diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py -index aee8251b..d157225a 100644 ---- a/python/torch_mlir/extras/fx_importer.py -+++ b/python/torch_mlir/extras/fx_importer.py -@@ -927,6 +927,19 @@ class ContextCache: - tensor_meta = node.meta.get("tensor_meta") - val = node.meta.get("val") - sparsity = node.meta.get("sparsity", None) -+ # Some nodes returns a list, like torch.ops.aten.unbind.int -+ if isinstance(tensor_meta, List) or isinstance(val, List): -+ if tensor_meta is not None and all(x is not None for x in tensor_meta): -+ # Assume that all results in the list are tensors. -+ # TODO: Solve this assumption -+ return IrType.parse("!torch.list", context=self._c) -+ elif val is not None and all(x is not None for x in val): -+ return IrType.parse("!torch.list", context=self._c) -+ else: -+ raise NotImplementedError( -+ f"FIXME: Unsupported placeholder node (this often indicates that a necessary) " -+ f"fx preprocessing pass was not run): {node.meta}" -+ ) - except KeyError as e: - raise RuntimeError( - f"FIXME: Illegal access to torch.fx.Node.meta: {e} ({node.meta.keys()} : {node.meta})" -@@ -1038,6 +1051,7 @@ class GraphNodeImporter: - "_on_node_produced", - "_v", - "_multi_result_nodes", -+ "_list_return_nodes", - "fx_importer", - ] - -@@ -1061,6 +1075,9 @@ class GraphNodeImporter: - # They will have their getitem calls short-circuited. - self._multi_result_nodes: Set[torch_fx.Node] = set() - -+ # Stores the node that returns a list, like aten.unbind.int -+ self._list_return_nodes: Set[torch_fx.Node] = set() -+ - def bind_node_value( - self, - node: Node, -@@ -1216,6 +1233,23 @@ class GraphNodeImporter: - f"notify developers if this case happens " - f"(at {loc})." - ) -+ elif getitem_ref in self._list_return_nodes: -+ fx_list_return_value = self._v[(getitem_ref, 0)] -+ operands = [ -+ fx_list_return_value, -+ self._import_default_value(loc, getitem_index, torch.IntType) -+ ] -+ -+ # We trust the tensor type in FX graph, even if it's a getitem -+ # from a value of MLIR ListType. -+ operation = Operation.create( -+ "torch.aten.__getitem__.t", -+ results=(self._cc.node_val_to_type(node),), -+ operands = operands, -+ loc=loc -+ ) -+ for i, value in enumerate(operation.results): -+ self._v[(node, i)] = value - else: - raise NotImplementedError( - f"General getitem access to non-multi-result ops" -@@ -1642,6 +1676,10 @@ class GraphNodeImporter: - # Unary return directly maps a single meta["val"] and cannot be subscripted. 
- # if "tensor_meta" is None, this will throw unsupported placeholder node error - result_types = [self._cc.node_val_to_type(node)] -+ -+ # separately handle ops returning list. -+ if str(result_types[0]).startswith("!torch.list"): -+ self._list_return_nodes.add(node) - elif return_count == 0: - # Some torch ops do have 0 returns, and these are supported with ZeroResults - # op trait. Python bindings for IR creation allow us to pass empty result_types diff --git a/frontends/torch-frontend/third_party/patches/gather_scatter.patch b/frontends/torch-frontend/third_party/patches/gather_scatter.patch new file mode 100644 index 000000000..d37d7cdff --- /dev/null +++ b/frontends/torch-frontend/third_party/patches/gather_scatter.patch @@ -0,0 +1,283 @@ +diff --git a/include/torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h b/include/torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h +index 734ba81e..03edf5d9 100644 +--- a/include/torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h ++++ b/include/torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h +@@ -49,7 +49,7 @@ Value promoteType(PatternRewriter &rewriter, Location loc, Value input, + TensorType outType); + + Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input, +- TensorType outType); ++ TensorType outType, std::optional bcastSizeTensor); + + SmallVector toPositiveDims(ArrayRef dims, int64_t rank); + +diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp +index d01f0daf..a5b1b28b 100644 +--- a/lib/Conversion/TorchToStablehlo/Basic.cpp ++++ b/lib/Conversion/TorchToStablehlo/Basic.cpp +@@ -766,7 +766,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( + getTypeConverter()->convertType(op->getResult(0).getType())); + + if (options.enableStaticShape && selfTy.hasStaticShape()) { +- Value bcastOp = hlo::promoteAndBroadcast(rewriter, self, outType); ++ Value bcastOp = hlo::promoteAndBroadcast(rewriter, self, outType, std::nullopt); + rewriter.replaceOp(op, bcastOp); + return success(); + } +@@ -1457,8 +1457,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( + .value()); + + // Apply affine transform: output x weight + bias [element-wise] +- auto bcastedWeight = hlo::promoteAndBroadcast(rewriter, weight, outputTy); +- auto bcastedBias = hlo::promoteAndBroadcast(rewriter, bias, outputTy); ++ auto bcastedWeight = hlo::promoteAndBroadcast(rewriter, weight, outputTy, std::nullopt); ++ auto bcastedBias = hlo::promoteAndBroadcast(rewriter, bias, outputTy, std::nullopt); + auto outputMulWeight = + rewriter.create(op->getLoc(), output, bcastedWeight); + auto finalOuput = rewriter.create( +@@ -1603,8 +1603,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( + maxValue = *maxInfo; + } + if (inputType.hasStaticShape()) { +- minValue = hlo::promoteAndBroadcast(rewriter, minValue, inputType); +- maxValue = hlo::promoteAndBroadcast(rewriter, maxValue, inputType); ++ minValue = hlo::promoteAndBroadcast(rewriter, minValue, inputType, std::nullopt); ++ maxValue = hlo::promoteAndBroadcast(rewriter, maxValue, inputType, std::nullopt); + } + rewriter.replaceOpWithNewOp(op, minValue, input, + maxValue); +@@ -2016,7 +2016,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( + + auto resultType = + cast(getTypeConverter()->convertType(op.getType())); +- rhs = hlo::promoteAndBroadcast(rewriter, rhs, resultType); ++ rhs = hlo::promoteAndBroadcast(rewriter, rhs, resultType, std::nullopt); + rewriter.replaceOpWithNewOp(op, lhs, rhs); + return success(); + } +@@ 
-2031,7 +2031,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( + + auto resultType = + cast(getTypeConverter()->convertType(op.getType())); +- rhs = hlo::promoteAndBroadcast(rewriter, rhs, resultType); ++ rhs = hlo::promoteAndBroadcast(rewriter, rhs, resultType, std::nullopt); + rewriter.replaceOpWithNewOp(op, lhs, rhs); + return success(); + } +diff --git a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp +index 00c022cc..6f4a503a 100644 +--- a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp ++++ b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp +@@ -194,6 +194,88 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor, + } // namespace + + namespace { ++ ++// A helper function to generate the final shape tensor if we broadcast ++// `tensors`. ++FailureOr getBroadcastSize(Operation *op, ++ ConversionPatternRewriter &rewriter, ++ SmallVector tensors, ++ size_t dimSizeIndexBits) { ++ SmallVector> tensorSizes; ++ ++ int maxRank = 0; ++ for (auto tensor : tensors) { ++ auto tensorType = cast(tensor.getType()); ++ auto tensorRank = tensorType.getRank(); ++ ++ tensorSizes.emplace_back(tensorType.getShape()); ++ maxRank = std::max(maxRank, static_cast(tensorRank)); ++ } ++ ++ SmallVector bcastSizeTensors; ++ for (int outDim = 0; outDim < maxRank; ++outDim) { // loop dimensions. ++ int dynamicDimCnt = 0; ++ int staticDimCnt = 0; ++ int64_t staticDimSize; ++ Value dimSizeTensor = rewriter.create( ++ op->getLoc(), ++ rewriter.getIntegerAttr(rewriter.getIntegerType(dimSizeIndexBits), 1)); ++ ++ for (size_t i = 0; i < tensorSizes.size(); ++i) { // loop tensors. ++ int inDim = tensorSizes[i].size() - 1 - outDim; ++ if (inDim < 0) ++ continue; ++ ++ // dim size 1 ++ if (tensorSizes[i][inDim] == 1) ++ continue; ++ // dim size dynamic ++ if (tensorSizes[i][inDim] == ShapedType::kDynamic || ++ tensorSizes[i][inDim] == kUnknownSize) { ++ dynamicDimCnt++; ++ auto dimSizeTensorInfo = hlo::getDimSizesOfTensor( ++ rewriter, op, tensors[i], {inDim}, dimSizeIndexBits); ++ if (failed(dimSizeTensorInfo)) { ++ llvm::outs() << "failed to generate tensor size\n"; ++ return failure(); ++ } ++ dimSizeTensor = (*dimSizeTensorInfo)[0]; ++ continue; ++ } ++ // dim size static ++ // we already found dynamic dim size, fail. ++ if (dynamicDimCnt > 0) { ++ llvm::outs() << "multi dynamic\n"; ++ return failure(); ++ } ++ // we already found static dim size not equal with this, fail. ++ if (staticDimCnt > 0 && staticDimSize != tensorSizes[i][inDim]) { ++ llvm::outs() << "unsame static dim size\n"; ++ return failure(); ++ } ++ ++ staticDimCnt++; ++ staticDimSize = tensorSizes[i][inDim]; ++ auto dimSizeTensorInfo = hlo::getDimSizesOfTensor( ++ rewriter, op, tensors[i], {inDim}, dimSizeIndexBits); ++ if (failed(dimSizeTensorInfo)) { ++ llvm::outs() << "failed to generate tensor size 2\n"; ++ return failure(); ++ } ++ dimSizeTensor = (*dimSizeTensorInfo)[0]; ++ } ++ // Relax this check, by assuming all dynamic shape is same. 
++ // if (dynamicDimCnt > 1) { ++ // return failure(); ++ // } ++ ++ bcastSizeTensors.push_back(dimSizeTensor); ++ } ++ std::reverse(bcastSizeTensors.begin(), bcastSizeTensors.end()); ++ return rewriter.create(op->getLoc(), bcastSizeTensors) ++ .getResult(); ++} ++ + // A helper function used to generate stablehlo's ScatterIndices or + // GatherIndices from torch's indices, usually appear in torch ops, like + // aten.index.Tensor or aten.input_put A usage example is as follow: Input: [[1, +@@ -216,28 +298,38 @@ FailureOr broadcastAndConcatIndices(Operation *op, + ConversionPatternRewriter &rewriter, + SmallVector indexTensors, + llvm::ArrayRef inputShape, ++ size_t dimSizeIndexBits, + int &maxIndexRank) { + // Step 1: broadcast indices tensors + SmallVector indicesShape; + SmallVector expandShape; + SmallVector concatShape; ++ ++ bool allIndexStaticShape = true; ++ Value bcastSizeTensor; ++ + // concat index tensor into to indices tensor for concat + for (size_t i = 0; i < indexTensors.size(); i++) { + auto indexTensor = indexTensors[i]; + auto indexTensorType = cast(indexTensor.getType()); + for (int64_t size : makeShapeTorchCompatible(indexTensorType.getShape())) { + if (size == kUnknownSize) +- return failure(); ++ allIndexStaticShape = false; + } + maxIndexRank = std::max(maxIndexRank, (int)indexTensorType.getRank()); + } + +- SmallVector refinedInputShape = makeShapeTorchCompatible(inputShape); +- for (int64_t size : refinedInputShape) { +- if (size == kUnknownSize) { ++ if (!allIndexStaticShape) { ++ auto bcastSizeTensorInfo = ++ getBroadcastSize(op, rewriter, indexTensors, dimSizeIndexBits); ++ if (failed(bcastSizeTensorInfo)) { ++ llvm::outs() << "failed here\n"; + return failure(); + } ++ bcastSizeTensor = *bcastSizeTensorInfo; + } ++ ++ llvm::ArrayRef refinedInputShape = inputShape; + for (int i = 0; i < maxIndexRank; i++) { + indicesShape.push_back(refinedInputShape[i]); + expandShape.push_back(refinedInputShape[i]); +@@ -252,12 +344,27 @@ FailureOr broadcastAndConcatIndices(Operation *op, + RankedTensorType bcastIndexType = + RankedTensorType::get(indicesShape, indexElemTy); + for (auto indexTensor : indexTensors) { +- Value bcastVal = +- hlo::promoteAndBroadcast(rewriter, indexTensor, bcastIndexType); ++ Value bcastVal; + RankedTensorType reshapeType = + RankedTensorType::get(expandShape, indexElemTy); +- bcastVal = rewriter.create(op->getLoc(), reshapeType, +- bcastVal); ++ if (allIndexStaticShape) { ++ bcastVal = hlo::promoteAndBroadcast(rewriter, indexTensor, bcastIndexType, ++ std::nullopt); ++ bcastVal = rewriter.create(op->getLoc(), ++ reshapeType, bcastVal); ++ } else { ++ bcastVal = hlo::promoteAndBroadcast(rewriter, indexTensor, bcastIndexType, ++ bcastSizeTensor); ++ auto bcastValShapeTensorVec = ++ *hlo::getDimSizesOfTensor(rewriter, op, bcastVal, dimSizeIndexBits); ++ bcastValShapeTensorVec.push_back(rewriter.create( ++ op->getLoc(), rewriter.getIntegerAttr( ++ rewriter.getIntegerType(dimSizeIndexBits), 1))); ++ Value bcastValShapeTensor = rewriter.create( ++ op->getLoc(), bcastValShapeTensorVec).getResult(); ++ bcastVal = rewriter.create( ++ op->getLoc(), reshapeType, bcastVal, bcastValShapeTensor); ++ } + broadcastedIndices.push_back(bcastVal); + } + +@@ -803,7 +910,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( + + int maxIndexRank = -1; + auto gatherIndicesInfo = broadcastAndConcatIndices(op, rewriter, indexTensors, +- outShape, maxIndexRank); ++ outShape, options.dimSizeIndexBits, maxIndexRank); + if (failed(gatherIndicesInfo)) { + return 
rewriter.notifyMatchFailure(
+ op, "failed to generate broadcasted indices");
+@@ -877,7 +984,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
+
+ int maxIndexRank = -1;
+ auto scatterIndicesInfo = broadcastAndConcatIndices(
+- op, rewriter, indexTensors, valuesShape, maxIndexRank);
++ op, rewriter, indexTensors, valuesShape, options.dimSizeIndexBits, maxIndexRank);
+ if (failed(scatterIndicesInfo)) {
+ return rewriter.notifyMatchFailure(
+ op, "failed to generate broadcasted indices");
+diff --git a/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp b/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp
+index c4d629d4..5332c204 100644
+--- a/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp
++++ b/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp
+@@ -156,12 +156,14 @@ Value promoteType(PatternRewriter &rewriter, Location loc, Value input,
+ }
+
+ Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
+- TensorType outType) {
++ TensorType outType, std::optional<Value> bcastSizeTensor) {
+ // Two tensors are “broadcastable” if the following rules hold:
+ // - Each tensor has at least one dimension.
+ // - When iterating over the dimension sizes, starting at the trailing
+ // dimension, the dimension sizes must either be equal, one of them is 1, or
+ // one of them does not exist.
++ // If a bcastSizeTensor is provided, emit stablehlo::DynamicBroadcastInDimOp
++ // instead of stablehlo::BroadcastInDimOp to support dynamic shapes.
+ Operation *op = input.getDefiningOp();
+ TensorType in_type = dyn_cast<TensorType>(input.getType());
+
+@@ -199,6 +201,10 @@ Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
+ return input;
+ }
+ auto bcast_attr = rewriter.getDenseI64ArrayAttr(bcastDims);
++ if (bcastSizeTensor.has_value()) {
++ auto bcast_op = rewriter.create<stablehlo::DynamicBroadcastInDimOp>(op->getLoc(), outType, input, bcastSizeTensor.value(), bcast_attr);
++ return bcast_op.getResult();
++ }
+ auto bcast_op = rewriter.create<stablehlo::BroadcastInDimOp>(
+ op->getLoc(), outType, input, bcast_attr);
+ return bcast_op.getResult();
diff --git a/frontends/torch-frontend/third_party/patches/index_add.patch b/frontends/torch-frontend/third_party/patches/index_add.patch
new file mode 100644
index 000000000..c74fdb5de
--- /dev/null
+++ b/frontends/torch-frontend/third_party/patches/index_add.patch
@@ -0,0 +1,260 @@
+commit 4e5577ad88fc99b93eec0ede85e61ad5c7a87e99
+Author: wujiawei.aml
+Date: Wed May 8 22:36:24 2024 +0800
+
+ [torch-dialect] emit aten.index_add and decompose it to scatter.add op
+
+diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+index 4de41e13..8734896c 100644
+--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
++++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+@@ -5703,6 +5703,59 @@ def Torch_AtenTril_Op : Torch_Op<"aten.tril_", [
+ }];
+ }
+
++def Torch_AtenIndexAddOp : Torch_Op<"aten.index_add", [
++ AllowsTypeRefinement,
++ HasValueSemantics,
++ ReadOnly
++ ]> {
++ let summary = "Generated op for `aten::index_add : (Tensor, int, Tensor, Tensor, Scalar) -> (Tensor)`";
++ let arguments = (ins
++ AnyTorchTensorType:$self,
++ Torch_IntType:$dim,
++ AnyTorchTensorType:$index,
++ AnyTorchTensorType:$source,
++ AnyTorchScalarType:$alpha
++ );
++ let results = (outs
++ AnyTorchTensorType:$result
++ );
++ let hasCustomAssemblyFormat = 1;
++ let extraClassDefinition = [{
++ ParseResult AtenIndexAddOp::parse(OpAsmParser &parser, OperationState &result) {
++ return
parseDefaultTorchOp(parser, result, 5, 1); ++ } ++ void AtenIndexAddOp::print(OpAsmPrinter &printer) { ++ printDefaultTorchOp(printer, *this, 5, 1); ++ } ++ }]; ++} ++ ++def Torch_AtenIndexAdd_Op : Torch_Op<"aten.index_add_", [ ++ IsTrailingUnderscoreInplaceVariant, ++ AllowsTypeRefinement ++ ]> { ++ let summary = "Generated op for `aten::index_add_ : (Tensor, int, Tensor, Tensor, Scalar) -> (Tensor)`"; ++ let arguments = (ins ++ Torch_NonValueTensorType:$self, ++ Torch_IntType:$dim, ++ Torch_NonValueTensorType:$index, ++ Torch_NonValueTensorType:$source, ++ AnyTorchScalarType:$alpha ++ ); ++ let results = (outs ++ Torch_NonValueTensorType:$result ++ ); ++ let hasCustomAssemblyFormat = 1; ++ let extraClassDefinition = [{ ++ ParseResult AtenIndexAdd_Op::parse(OpAsmParser &parser, OperationState &result) { ++ return parseDefaultTorchOp(parser, result, 5, 1); ++ } ++ void AtenIndexAdd_Op::print(OpAsmPrinter &printer) { ++ printDefaultTorchOp(printer, *this, 5, 1); ++ } ++ }]; ++} ++ + def Torch_AtenIndexPutOp : Torch_Op<"aten.index_put", [ + AllowsTypeRefinement, + HasValueSemantics, +diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +index 43bcc3ac..a414cca7 100644 +--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp ++++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +@@ -9185,6 +9185,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { + " %0 = call @__torch__.torch.jit._shape_functions.index_select(%arg0, %arg1, %arg2) : (!torch.list, !torch.int, !torch.list) -> !torch.list\n" + " return %0 : !torch.list\n" + " }\n" ++" func.func @\"__torch_mlir_shape_fn.aten.scatter_add\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.list, %arg3: !torch.list) -> !torch.list {\n" ++" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" ++" return %0 : !torch.list\n" ++" }\n" ++" func.func @\"__torch_mlir_shape_fn.aten.index_add\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.float) -> !torch.list {\n" ++" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" ++" return %0 : !torch.list\n" ++" }\n" + " func.func @\"__torch_mlir_shape_fn.aten.index_put\"(%arg0: !torch.list, %arg1: !torch.list>>, %arg2: !torch.list, %arg3: !torch.bool) -> !torch.list {\n" + " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" + " return %0 : !torch.list\n" +@@ -10399,6 +10407,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { + " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" + " return %0#1 : !torch.int\n" + " }\n" ++" func.func @\"__torch_mlir_dtype_fn.aten.index_add\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.tuple, %arg3: !torch.tuple, %arg4: !torch.number) -> !torch.int {\n" ++" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" ++" return %0#1 : !torch.int\n" ++" }\n" ++" func.func @\"__torch_mlir_dtype_fn.aten.scatter_add\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.tuple, %arg3: !torch.tuple) -> !torch.int {\n" ++" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" ++" return %0#1 : !torch.int\n" ++" }\n" + " func.func @\"__torch_mlir_dtype_fn.aten.index_select\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.tuple) -> !torch.int {\n" + " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, 
!torch.int\n" + " return %0#1 : !torch.int\n" +diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +index 5ec22233..e5a6b2fe 100644 +--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp ++++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +@@ -5621,6 +5621,75 @@ public: + }; + } // namespace + ++namespace { ++// Decompose `aten.index_add` op into `aten.index_put` ++class DecomposeAtenIndexAddOp : public OpRewritePattern { ++public: ++ using OpRewritePattern::OpRewritePattern; ++ LogicalResult matchAndRewrite(AtenIndexAddOp op, ++ PatternRewriter &rewriter) const override { ++ Location loc = op.getLoc(); ++ Value src = op.getSource(); ++ Value input = op.getSelf(); ++ Value index = op.getIndex(); ++ Value alpha = op.getAlpha(); ++ ++ int64_t dim; ++ if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) { ++ return rewriter.notifyMatchFailure(op, ++ "dim of index_add must be a constant"); ++ } ++ std::optional maybeInputRank = getTensorRank(input); ++ if (!maybeInputRank) { ++ return rewriter.notifyMatchFailure(op, "expected input to have a rank"); ++ } ++ int64_t inputRank = static_cast(*maybeInputRank); ++ dim = toPositiveDim(dim, inputRank); ++ if (!isValidDim(dim, inputRank)) { ++ return rewriter.notifyMatchFailure(op, "index dim is not a valid dim"); ++ } ++ ++ auto resType = op.getType().cast(); ++ auto srcType = src.getType().cast(); ++ auto indexType = index.getType().cast(); ++ if (!indexType.hasDtype()) { ++ return rewriter.notifyMatchFailure(op, "index should have dtype"); ++ } ++ auto indexDtype = indexType.getDtype(); ++ ++ // calculate src * alpha first. ++ Value newSrc = ++ rewriter.create(loc, srcType, src, alpha); ++ ++ // broadcast index to have the same shape as src. 
++ Value constMinusOne = rewriter.create<ConstantIntOp>(
++ loc, rewriter.getI64IntegerAttr(-1));
++ for (int64_t i = dim + 1; i < inputRank; ++i) {
++ index = *unsqueezeTensor(rewriter, op, index, /*dim=*/constMinusOne);
++ }
++
++ SmallVector<int64_t> bcastShape;
++ SmallVector<Value> bcastShapeValue;
++ computeBroadcastShape(rewriter, loc, index, src, bcastShape,
++ bcastShapeValue);
++
++ Type bcastType = ValueTensorType::get(
++ op.getContext(), llvm::ArrayRef(bcastShape), indexDtype);
++
++ Value indexBcastShapeTorchList = rewriter.create<PrimListConstructOp>(
++ loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
++ bcastShapeValue);
++
++ index = rewriter.create<AtenBroadcastToOp>(loc, bcastType, index,
++ indexBcastShapeTorchList);
++
++ rewriter.replaceOpWithNewOp<AtenScatterAddOp>(op, resType, input,
++ op.getDim(), index, newSrc);
++ return success();
++ }
++};
++} // namespace
++
+ namespace {
+ class DecomposeAtenExpandAsOp : public OpRewritePattern<AtenExpandAsOp> {
+ using OpRewritePattern<AtenExpandAsOp>::OpRewritePattern;
+@@ -8021,6 +8090,7 @@ public:
+ addPatternIfTargetOpIsIllegal(patterns);
+ addPatternIfTargetOpIsIllegal(patterns);
+ addPatternIfTargetOpIsIllegal(patterns);
++ addPatternIfTargetOpIsIllegal<DecomposeAtenIndexAddOp>(patterns);
+ addPatternIfTargetOpIsIllegal(patterns);
+ addPatternIfTargetOpIsIllegal(patterns);
+ addPatternIfTargetOpIsIllegal(patterns);
+diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+index 0ca7ea9c..39bb88e5 100644
+--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
++++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+@@ -471,6 +471,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
+ target.addIllegalOp();
+ target.addIllegalOp();
+ target.addIllegalOp();
++ target.addIllegalOp<AtenIndexAddOp>();
+ target.addIllegalOp();
+ target.addIllegalOp();
+ target.addIllegalOp();
+diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+index 1cf0c2c7..1d7fd7b3 100644
+--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
++++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+@@ -1607,15 +1607,18 @@ def aten〇scatter〇value〡shape(self: List[int], dim: int, index: List[int],
+ def aten〇index_select〡shape(self: List[int], dim: int, index: List[int]) -> List[int]:
+ return upstream_shape_functions.index_select(self, dim, index)
+
++def aten〇scatter_add〡shape(self: List[int], dim: int, index: List[int], src: List[int]) -> List[int]:
++ return upstream_shape_functions.unary(self)
++
++def aten〇index_add〡shape(self: List[int], dim: int, index: List[int], source: List[int], alpha: float = 1) -> List[int]:
++ return upstream_shape_functions.unary(self)
++
+ def aten〇index_put〡shape(self: List[int], indices: List[Optional[List[int]]], values: List[int], accumulate: bool = False) -> List[int]:
+ return upstream_shape_functions.unary(self)
+
+ def aten〇index_put〇hacked_twin〡shape(self: List[int], indices: List[List[int]], values: List[int], accumulate: bool = False) -> List[int]:
+ return upstream_shape_functions.unary(self)
+
+-def aten〇embedding〡shape(weight: List[int], indices: List[int], padding_idx: int = -1, scale_grad_by_freq: bool = False, sparse: bool = False) -> List[int]:
+- return upstream_shape_functions.embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse)
+-
+ def aten〇embedding_bag〇padding_idx〡shape(weight: List[int], indices: List[int], offsets:
List[int], scale_grad_by_freq: bool, mode: int, sparse: bool, per_sample_weights: Optional[List[int]], include_last_offset: bool, padding_idx: Optional[int]) -> Tuple[List[int], List[int], List[int], List[int]]: + return _embedding_bag_helper(weight, indices, offsets, include_last_offset, + mode, per_sample_weights, padding_idx) +@@ -2534,6 +2537,16 @@ def aten〇index_put〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtyp + self_rank, self_dtype = self_rank_dtype + return self_dtype + ++@check_dtype_function([Invocation(TensorOfShape(3, dtype=dtype), 0, TensorOfShape(3, dtype=torch.int64), TensorOfShape(3, dtype=dtype)) for dtype in _SORTED_TORCH_TYPES]) ++def aten〇index_add〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int], source_rank_dtype: Tuple[int, int], alpha: Union[int, float, complex] = 1) -> int: ++ self_rank, self_dtype = self_rank_dtype ++ return self_dtype ++ ++@check_dtype_function([Invocation(TensorOfShape(3, dtype=dtype), 0, TensorOfShape(3, dtype=torch.int64), TensorOfShape(3, dtype=dtype)) for dtype in _SORTED_TORCH_TYPES]) ++def aten〇scatter_add〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int], src_rank_dtype: Tuple[int, int]) -> int: ++ self_rank, self_dtype = self_rank_dtype ++ return self_dtype ++ + @check_dtype_function(_check_tensors_with_the_same_dtype(None, [(5,)], None, None, 0, TensorOfShape(1, dtype=torch.int64))) + def aten〇index_select〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype +diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +index c847e42d..e7fd7cf7 100644 +--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py ++++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +@@ -516,6 +516,8 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): + + emit_with_mutating_variants("aten::triu : (Tensor, int) -> (Tensor)") + emit_with_mutating_variants("aten::tril : (Tensor, int) -> (Tensor)") ++ emit_with_mutating_variants("aten::index_add : (Tensor, int, Tensor, Tensor, Scalar) -> (Tensor)") ++ + emit_with_mutating_variants( + "aten::index_put : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)" + ) diff --git a/frontends/torch-frontend/torch-frontend/lib/Transforms/EliminateUselessOp.cpp b/frontends/torch-frontend/torch-frontend/lib/Transforms/EliminateUselessOp.cpp index 848ea590f..4ffc785f8 100644 --- a/frontends/torch-frontend/torch-frontend/lib/Transforms/EliminateUselessOp.cpp +++ b/frontends/torch-frontend/torch-frontend/lib/Transforms/EliminateUselessOp.cpp @@ -59,6 +59,18 @@ struct EliminateAtenWarnOp : public OpRewritePattern { }; } // namespace +namespace { +// This is probably buggy. 
Maybe we should pass the runtime assert through instead of erasing it.
+struct EliminateAtenRuntimeAssertOp : public OpRewritePattern<RuntimeAssertOp> {
+ using OpRewritePattern<RuntimeAssertOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(RuntimeAssertOp op,
+ PatternRewriter &rewriter) const override {
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+} // namespace
+
 namespace {
 struct EliminateUselessOpPass
 : public EliminateUselessOpBase<EliminateUselessOpPass> {
@@ -70,6 +82,8 @@ struct EliminateUselessOpPass
 patterns.add(context, "profiler.");
 // Eliminate torch.aten.warn op
 patterns.add<EliminateAtenWarnOp>(context);
+ // Eliminate torch.runtime.assert op
+ patterns.add<EliminateAtenRuntimeAssertOp>(context);
 FrozenRewritePatternSet frozenPatterns(std::move(patterns));
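The `getBroadcastSize` / `promoteAndBroadcast` changes above all revolve around the usual right-aligned broadcasting rule: walking shapes from the trailing dimension, each pair of sizes must be equal, be 1, or be absent, and dynamic sizes are materialized into a runtime shape tensor for `stablehlo.dynamic_broadcast_in_dim`. A minimal Python sketch of the static part of that rule follows (the function name is illustrative only and not part of the patch):

```python
import itertools

def broadcast_shape(*shapes):
    """Right-aligned broadcast: trailing sizes must match, be 1, or be absent."""
    out = []
    for dims in itertools.zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        non_one = {d for d in dims if d != 1}   # sizes that actually constrain this dim
        if len(non_one) > 1:
            raise ValueError(f"sizes {dims} are not broadcastable")
        out.append(non_one.pop() if non_one else 1)
    return tuple(reversed(out))

assert broadcast_shape((3, 1), (1, 4)) == (3, 4)
assert broadcast_shape((8, 1, 5), (7, 1)) == (8, 7, 5)
```

For dynamic dimensions the patch simply reuses whichever per-dimension size tensor it finds, i.e. it assumes all dynamic sizes along a dimension agree, which is why the stricter `dynamicDimCnt > 1` check is left commented out.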
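The index_add.patch decomposition can be cross-checked against eager PyTorch: scale the source by alpha, unsqueeze and broadcast the index up to the source's shape, then scatter-add along `dim`. A small sketch under those assumptions (the helper name is hypothetical; only the public `torch.Tensor.index_add` and `torch.Tensor.scatter_add` APIs are used):

```python
import torch

def index_add_via_scatter_add(dest, dim, index, source, alpha=1):
    # Step 1: scale the source by alpha.
    src = source * alpha
    # Step 2: unsqueeze the index over the dims trailing `dim`, then broadcast to src's shape.
    view = index.reshape(index.shape + (1,) * (source.dim() - dim - 1))
    idx = view.expand(*source.shape)
    # Step 3: accumulate along `dim` with scatter_add.
    return dest.scatter_add(dim, idx, src)

dest = torch.zeros(4, 3)
index = torch.tensor([0, 2, 2])
source = torch.arange(9, dtype=torch.float32).reshape(3, 3)
ref = dest.index_add(0, index, source, alpha=2.0)
out = index_add_via_scatter_add(dest, 0, index, source, alpha=2.0)
assert torch.equal(ref, out)
```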